
Make style

Patrick von Platen
2022-06-27 15:59:04 +00:00
parent 932ce05d97
commit 4261c3aadf
25 changed files with 451 additions and 290 deletions

View File

@@ -34,13 +34,9 @@ autogenerate_code: deps_table_update
# Check that the repo is in a good state
repo-consistency:
python utils/check_copies.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_inits.py
python utils/check_config_docstrings.py
python utils/tests_fetcher.py --sanity_check
# this target runs checks on all files
@@ -48,14 +44,13 @@ quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
flake8 $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
# Format source code automatically and check if there are any problems left that need manual fixing
extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source
# this target runs checks on all files and potentially modifies some of them
@@ -73,8 +68,6 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
fix-copies:
python utils/check_dummies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_copies.py --fix_and_overwrite
# Run tests for the library

View File

@@ -47,12 +47,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def init_git_repo(args, at_init: bool = False):
"""
Initializes a git repo in `args.hub_model_id`.
Args:
Initializes a git repo in `args.hub_model_id`.
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
out.
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
"""
if args.local_rank not in [-1, 0]:
return
@@ -102,8 +101,8 @@ def push_to_hub(
**kwargs,
) -> str:
"""
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Parameters:
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
commit_message (`str`, *optional*, defaults to `"End of training"`):
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
@@ -111,8 +110,8 @@ def push_to_hub(
kwargs:
Additional keyword arguments passed along to [`create_model_card`].
Returns:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
the commit and an object to track the progress of the commit if `blocking=True`
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
commit and an object to track the progress of the commit if `blocking=True`
"""
if args.hub_model_id is None:

View File

@@ -123,16 +123,16 @@ class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
downloading and saving models as well as a few methods common to all models to:
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
and saving models as well as a few methods common to all models to:
- resize the input embeddings,
- prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class
for this model architecture.
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
@@ -227,8 +227,8 @@ class ModelMixin(torch.nn.Module):
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
e.g., `./my_model_directory/`.
config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
Can be either:
@@ -236,13 +236,13 @@ class ModelMixin(torch.nn.Module):
- an instance of a class derived from [`ConfigMixin`],
- a string or path valid as input to [`~ConfigMixin.from_pretrained`].
ConfigMixinuration for the model to use instead of an automatically loaded configuration. ConfigMixinuration can
be automatically loaded when:
ConfigMixinuration for the model to use instead of an automatically loaded configuration.
ConfigMixinuration can be automatically loaded when:
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
model).
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the
save directory.
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
cache_dir (`Union[str, os.PathLike]`, *optional*):
@@ -292,10 +292,10 @@ class ModelMixin(torch.nn.Module):
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that
corresponds to a configuration attribute will be used to override said attribute with the
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
will be passed to the underlying model's `__init__` function.
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
to a configuration attribute will be used to override said attribute with the supplied `kwargs`
value. Remaining keys that do not correspond to any configuration attribute will be passed to the
underlying model's `__init__` function.
<Tip>
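A short usage sketch of the two loading paths this docstring describes; the class name, repo id, and directory below are placeholders rather than anything taken from this diff:

```python
from diffusers import UNetModel  # placeholder: any concrete ModelMixin subclass

# load by hub model id (weights are downloaded and cached, optionally into cache_dir)
model = UNetModel.from_pretrained("some-org/some-model", cache_dir="./hf-cache")

# or load from a local directory previously written by `save_pretrained`
model.save_pretrained("./my_model_directory/")
model = UNetModel.from_pretrained("./my_model_directory/")
```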

View File

@@ -22,14 +22,12 @@ def get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
Create sinusoidal timestep embeddings.
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
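For reference, a minimal standalone sketch of the sinusoidal embedding this docstring describes (half sine, half cosine, with frequencies spanning from 1 down to roughly 1/max_period); the `flip_sin_to_cos` and `downscale_freq_shift` options of the real function are not reproduced here:

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps, embedding_dim, max_period=10000):
    # timesteps: 1-D tensor of N (possibly fractional) indices -> [N x embedding_dim] embeddings
    half = embedding_dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

print(sinusoidal_timestep_embedding(torch.arange(4), embedding_dim=8).shape)  # torch.Size([4, 8])
```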

View File

@@ -58,9 +58,8 @@ class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
@@ -97,9 +96,8 @@ class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
@@ -143,9 +141,8 @@ class GlideUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
@@ -171,10 +168,9 @@ class GlideUpsample(nn.Module):
class LDMUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
If 3D, then
upsampling occurs in the inner-two dimensions.
"""

View File

@@ -82,8 +82,7 @@ def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
@@ -111,8 +110,7 @@ class TimestepBlock(nn.Module):
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def forward(self, x, emb, encoder_out=None):
@@ -130,9 +128,8 @@ class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
@@ -158,17 +155,13 @@ class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def __init__(
@@ -235,8 +228,7 @@ class ResBlock(TimestepBlock):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
@@ -320,8 +312,8 @@ class QKVAttention(nn.Module):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
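A standalone sketch of the attention computation this docstring describes, for a packed tensor of shape [N x (H * 3 * C) x T]; the head-splitting order below follows that layout (the other QKVAttention variant in this file packs heads differently):

```python
import math
import torch

def qkv_attention(qkv, n_heads):
    # qkv: [N, 3 * H * C, T] -> [N, H * C, T]
    bs, width, length = qkv.shape
    assert width % (3 * n_heads) == 0
    ch = width // (3 * n_heads)
    q, k, v = qkv.reshape(bs * n_heads, 3 * ch, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
    out = torch.einsum("bts,bcs->bct", weight, v)
    return out.reshape(bs, n_heads * ch, length)

print(qkv_attention(torch.randn(2, 3 * 4 * 8, 16), n_heads=4).shape)  # torch.Size([2, 32, 16])
```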
@@ -343,29 +335,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param
out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling.
"""
def __init__(
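The `use_scale_shift_norm` option mentioned in the docstring above is FiLM-like conditioning: the timestep embedding predicts a per-channel scale and shift that modulate the normalized feature map inside each residual block. A minimal sketch (the module names and dimensions here are illustrative, not the actual ResBlock wiring):

```python
import torch
import torch.nn as nn

N, C, emb_dim = 4, 64, 128
h = torch.randn(N, C, 32, 32)          # intermediate feature map
emb = torch.randn(N, emb_dim)          # timestep (or class) embedding
emb_proj = nn.Linear(emb_dim, 2 * C)   # predicts a per-channel scale and shift
norm = nn.GroupNorm(32, C)

scale, shift = emb_proj(emb)[..., None, None].chunk(2, dim=1)
h = norm(h) * (1 + scale) + shift      # FiLM-style modulation of the normalized features
```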
@@ -571,10 +558,8 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
:param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N]
Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs.
"""
hs = []

View File

@@ -222,11 +222,8 @@ class BasicTransformerBlock(nn.Module):
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
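A sketch of the reshape pattern the docstring describes: project with a 1x1 convolution, flatten the spatial grid into a token axis, run a transformer-style block, and reshape back to an image (an `nn.Identity` stands in for the actual stack of BasicTransformerBlocks):

```python
import torch
import torch.nn as nn

b, c, h, w = 2, 64, 16, 16
x = torch.randn(b, c, h, w)

proj_in = nn.Conv2d(c, c, kernel_size=1)
transformer = nn.Identity()                 # placeholder for the BasicTransformerBlock stack
proj_out = nn.Conv2d(c, c, kernel_size=1)

t = proj_in(x)
t = t.flatten(2).transpose(1, 2)            # (b, c, h, w) -> (b, h*w, c): tokens of dim c
t = transformer(t)                          # standard transformer action over the h*w tokens
t = t.transpose(1, 2).reshape(b, c, h, w)   # back to image layout
out = proj_out(t) + x                       # residual connection, as is typical for this block
print(out.shape)                            # torch.Size([2, 64, 16, 16])
```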
@@ -331,8 +328,7 @@ def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
@@ -382,8 +378,7 @@ class TimestepBlock(nn.Module):
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def forward(self, x, emb, context=None):
@@ -399,10 +394,9 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
If 3D, then
downsampling occurs in the inner-two dimensions.
"""
@@ -426,18 +420,14 @@ class Downsample(nn.Module):
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def __init__(
@@ -525,8 +515,8 @@ class ResBlock(TimestepBlock):
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
@@ -575,9 +565,8 @@ class QKVAttention(nn.Module):
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
@@ -600,13 +589,9 @@ class QKVAttention(nn.Module):
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
@@ -629,9 +614,8 @@ class QKVAttentionLegacy(nn.Module):
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
@@ -650,31 +634,25 @@ class QKVAttentionLegacy(nn.Module):
class UNetLDMModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
rates at which
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
@@ -975,12 +953,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if
class-conditional. :return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
@@ -1012,8 +987,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
class EncoderUNetModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
The half UNet model with attention and timestep embedding. For usage, see UNet.
"""
def __init__(
@@ -1197,10 +1171,8 @@ class EncoderUNetModel(nn.Module):
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
of timesteps. :return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)

View File

@@ -111,10 +111,8 @@ class ResidualTemporalBlock(nn.Module):
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)

View File

@@ -136,26 +136,21 @@ def naive_downsample_2d(x, factor=2):
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
Padding is performed only once at the beginning, not between the
operations.
The fused op is considerably more efficient than performing the same
calculation
using standard TensorFlow ops. It supports gradients of arbitrary order.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels =
x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -208,25 +203,21 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations.
The fused op is considerably more efficient than performing the same
calculation
using standard TensorFlow ops. It supports gradients of arbitrary order.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels =
x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -258,22 +249,16 @@ def _shape(x, dim):
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
and upsamples each image with the given filter. The filter is normalized so
that
if the input pixels are constant, they will be scaled by the specified
`gain`.
Pixels outside the image are assumed to be zero, and the filter is padded
with
zeros so that its shape is a multiple of the upsampling factor.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
@@ -289,22 +274,16 @@ def upsample_2d(x, k=None, factor=2, gain=1):
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
and downsamples each image with the given filter. The filter is normalized
so that
if the input pixels are constant, they will be scaled by the specified
`gain`.
Pixels outside the image are assumed to be zero, and the filter is padded
with
zeros so that its shape is a multiple of the downsampling factor.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
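As these docstrings note, the default FIR filter `k = [1] * factor` amounts to nearest-neighbor upsampling (or average pooling in the downsampling case) once the filter is normalized. A small sketch of how such a kernel is typically constructed, under the assumption that the fused ops build it this way:

```python
import numpy as np

def make_fir_kernel(k=None, factor=2, gain=1.0, upsample=True):
    # default filter [1] * factor -> nearest-neighbor upsampling / average pooling
    k = np.asarray([1] * factor if k is None else k, dtype=np.float64)
    if k.ndim == 1:
        k = np.outer(k, k)   # separable filter -> 2-D kernel
    k = k / k.sum()          # normalize so constant inputs are scaled by exactly `gain`
    return k * (gain * factor**2 if upsample else gain)

# 2x2 kernel of ones: zero-stuffing the input and convolving with it reproduces nearest-neighbor upsampling
print(make_fir_kernel(factor=2))
```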

View File

@@ -290,7 +290,7 @@ def normalize_numbers(text):
return text
""" from https://github.com/keithito/tacotron """
""" from https://github.com/keithito/tacotron"""
_pad = "_"
@@ -322,8 +322,8 @@ def get_arpabet(word, dictionary):
def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on
{HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence

View File

@@ -29,8 +29,7 @@ from ..pipeline_utils import DiffusionPipeline
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
"""
Embed a diffusion step $t$ into a higher dimensional space
E.g. the embedding vector in the 128-dimensional space is
[sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)),
E.g. the embedding vector in the 128-dimensional space is [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)),
cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]
Parameters:
@@ -53,8 +52,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
"""
Below scripts were borrowed from
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
Below scripts were borrowed from https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""

View File

@@ -699,9 +699,8 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
:param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. :param
broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""

View File

@@ -1,4 +1,4 @@
""" from https://github.com/jaywalnut310/glow-tts """
""" from https://github.com/jaywalnut310/glow-tts"""
import math

View File

@@ -554,11 +554,9 @@ class LDMBertModel(LDMBertPreTrainedModel):
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
@@ -1055,8 +1053,8 @@ class Decoder(nn.Module):
class VectorQuantizer(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
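A minimal sketch of the nearest-codebook lookup at the heart of such a quantizer: squared distances are expanded as ||z||^2 - 2 z.e + ||e||^2 and the winning code is fetched by index rather than via a one-hot matrix product (this illustrates the technique, not the actual VectorQuantizer code):

```python
import torch

def nearest_codes(z_flat, codebook):
    # z_flat: [M, D] flattened latents; codebook: [K, D] embedding vectors
    d = (z_flat.pow(2).sum(1, keepdim=True)
         - 2 * z_flat @ codebook.t()
         + codebook.pow(2).sum(1))      # [M, K] squared distances
    indices = d.argmin(dim=1)           # closest code per latent
    return codebook[indices], indices   # direct index lookup, no one-hot matmul

z_q, idx = nearest_codes(torch.randn(8, 4), torch.randn(16, 4))
print(z_q.shape, idx.shape)  # torch.Size([8, 4]) torch.Size([8])
```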

View File

@@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
from 0 to 1 and
produces the cumulative product of (1-beta) up to that part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""

View File

@@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
from 0 to 1 and
produces the cumulative product of (1-beta) up to that part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""

View File

@@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
from 0 to 1 and
produces the cumulative product of (1-beta) up to that part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""

View File

@@ -20,11 +20,10 @@ class EMAModel:
):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
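The decay values quoted above are consistent with the usual EMA-warmup schedule `decay = 1 - (1 + step / inv_gamma) ** -power`; the formula itself is not shown in this hunk, so treat the sketch below as an assumption based on @crowsonkb's EMA warmup:

```python
def ema_decay(step, inv_gamma=1.0, power=2 / 3, max_decay=0.9999):
    return min(max_decay, 1 - (1 + step / inv_gamma) ** -power)

# reproduces the numbers quoted in the note
print(round(ema_decay(31_623), 4))               # ~0.999  at ~31.6K steps (power=2/3)
print(round(ema_decay(1_000_000), 4))            # ~0.9999 at 1M steps    (power=2/3)
print(round(ema_decay(10_000, power=3 / 4), 4))  # ~0.999  at 10K steps   (power=3/4)
```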

View File

@@ -89,20 +89,20 @@ class RevisionNotFoundError(HTTPError):
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip:
`pip install transformers`
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip:
`pip install Unidecode`
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip:
`pip install inflect`
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

View File

@@ -3,7 +3,7 @@
from ..utils import DummyObject, requires_backends
class GradTTS(metaclass=DummyObject):
class GradTTSPipeline(metaclass=DummyObject):
_backends = ["transformers", "inflect", "unidecode"]
def __init__(self, *args, **kwargs):

View File

@@ -31,14 +31,14 @@ class UNetGradTTSModel(metaclass=DummyObject):
requires_backends(self, ["transformers"])
class Glide(metaclass=DummyObject):
class GlidePipeline(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
class LatentDiffusion(metaclass=DummyObject):
class LatentDiffusionPipeline(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):

View File

@@ -233,8 +233,8 @@ def disable_propagation() -> None:
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to
prevent double logging if the root logger has been configured.
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
double logging if the root logger has been configured.
"""
_configure_library_root_logger()

View File

@@ -22,7 +22,6 @@ import numpy as np
import torch
from diffusers import (
GradTTSPipeline,
BDDMPipeline,
DDIMPipeline,
DDIMScheduler,
@@ -31,6 +30,7 @@ from diffusers import (
GlidePipeline,
GlideSuperResUNetModel,
GlideTextToImageUNetModel,
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
NCSNpp,

View File

@@ -24,7 +24,7 @@ from doc_builder.style_doc import style_docstrings_in_code
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
TRANSFORMERS_PATH = "src/diffusers"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
@@ -76,7 +76,7 @@ def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_transformers(object_name):
def find_code_in_diffusers(object_name):
"""Find and return the code source code of `object_name`."""
parts = object_name.split(".")
i = 0
@@ -88,9 +88,7 @@ def find_code_in_transformers(object_name):
if i < len(parts):
module = os.path.join(module, parts[i])
if i >= len(parts):
raise ValueError(
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
)
raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
@@ -121,7 +119,7 @@ def find_code_in_transformers(object_name):
return "".join(code_lines)
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
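For context, the regex above (now pointing at `diffusers` instead of `transformers`) matches markers of the following form; `make fix-copies` then re-syncs the code under such a comment with the referenced source, applying any `Old->New` replacement. The marker below is a hypothetical illustration:

```python
import re

_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")

line = "# Copied from diffusers.models.resnet.Upsample with Upsample->MyUpsample"
indent, object_name, replace_pattern = _re_copy_warning.search(line).groups()
print(object_name)      # models.resnet.Upsample  (module path inside src/diffusers)
print(replace_pattern)  # with Upsample->MyUpsample
```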
@@ -167,7 +165,7 @@ def is_copy_consistent(filename, overwrite=False):
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
theoretical_code = find_code_in_transformers(object_name)
theoretical_code = find_code_in_diffusers(object_name)
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
@@ -235,7 +233,9 @@ def check_copies(overwrite: bool = False):
+ diff
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
)
check_model_list_copy(overwrite=overwrite)
# check_model_list_copy(overwrite=overwrite)
def check_full_copies(overwrite: bool = False):
@@ -348,8 +348,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
def convert_readme_to_index(model_list):
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
return model_list.replace("https://huggingface.co/docs/transformers/", "")
model_list = model_list.replace("https://huggingface.co/docs/diffusers/main/", "")
return model_list.replace("https://huggingface.co/docs/diffusers/", "")
def _find_text_in_file(filename, start_prompt, end_prompt):
@@ -383,9 +383,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
# Fix potential doc links in the README
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
readme = f.read()
new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
new_readme = readme.replace("https://huggingface.co/diffusers", "https://huggingface.co/docs/diffusers")
new_readme = new_readme.replace(
"https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
"https://huggingface.co/docs/main/diffusers", "https://huggingface.co/docs/diffusers/main"
)
if new_readme != readme:
if overwrite:

utils/custom_init_isort.py (new file, 250 lines)
View File

@@ -0,0 +1,250 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
PATH_TO_TRANSFORMERS = "src/diffusers"
# Pattern that looks at the indentation in a line.
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":" and puts `key` in group 0.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches `"key",` and puts `key` in group 0.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
def get_indent(line):
"""Returns the indent in `line`."""
search = _re_indent.search(line)
return "" if search is None else search.groups()[0]
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
"""
Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
`start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
"""
# Let's split the code into lines and move to start_index.
index = 0
lines = code.split("\n")
if start_prompt is not None:
while not lines[index].startswith(start_prompt):
index += 1
blocks = ["\n".join(lines[:index])]
else:
blocks = []
# We split into blocks until we get to the `end_prompt` (or the end of the block).
current_block = [lines[index]]
index += 1
while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
current_block.append(lines[index])
blocks.append("\n".join(current_block))
if index < len(lines) - 1:
current_block = [lines[index + 1]]
index += 1
else:
current_block = []
else:
blocks.append("\n".join(current_block))
current_block = [lines[index]]
else:
current_block.append(lines[index])
index += 1
# Adds current block if it's nonempty.
if len(current_block) > 0:
blocks.append("\n".join(current_block))
# Add final block after end_prompt if provided.
if end_prompt is not None and index < len(lines):
blocks.append("\n".join(lines[index:]))
return blocks
def ignore_underscore(key):
"Wraps a `key` (that maps an object to string) to lower case and remove underscores."
def _inner(x):
return key(x).lower().replace("_", "")
return _inner
def sort_objects(objects, key=None):
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
# If no key is provided, we use a noop.
def noop(x):
return x
if key is None:
key = noop
# Constants are all uppercase, they go first.
constants = [obj for obj in objects if key(obj).isupper()]
# Classes are not all uppercase but start with a capital, they go second.
classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
# Functions begin with a lowercase, they go last.
functions = [obj for obj in objects if not key(obj)[0].isupper()]
key1 = ignore_underscore(key)
return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
def sort_objects_in_import(import_statement):
"""
Return the same `import_statement` but with objects properly sorted.
"""
# This inner function sort imports between [ ].
def _replace(match):
imports = match.groups()[0]
if "," not in imports:
return f"[{imports}]"
keys = [part.strip().replace('"', "") for part in imports.split(",")]
# We will have a final empty element if the line finished with a comma.
if len(keys[-1]) == 0:
keys = keys[:-1]
return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
lines = import_statement.split("\n")
if len(lines) > 3:
# Here we have to sort internal imports that are on several lines (one per name):
# key: [
# "object1",
# "object2",
# ...
# ]
# We may have to ignore one or two lines on each side.
idx = 2 if lines[1].strip() == "[" else 1
keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
elif len(lines) == 3:
# Here we have to sort internal imports that are on one separate line:
# key: [
# "object1", "object2", ...
# ]
if _re_bracket_content.search(lines[1]) is not None:
lines[1] = _re_bracket_content.sub(_replace, lines[1])
else:
keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
# We will have a final empty element if the line finished with a comma.
if len(keys[-1]) == 0:
keys = keys[:-1]
lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
return "\n".join(lines)
else:
# Finally we have to deal with imports fitting on one line
import_statement = _re_bracket_content.sub(_replace, import_statement)
return import_statement
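# Illustration (hypothetical input): on a single-line entry this gives
#   sort_objects_in_import('_import_structure["models"].extend(["unet_ldm", "AutoencoderKL", "VQModel"])')
#   -> '_import_structure["models"].extend(["AutoencoderKL", "VQModel", "unet_ldm"])'
# i.e. constants first, then classes, then lowercase names, each group sorted ignoring case and underscores.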
def sort_imports(file, check_only=True):
"""
Sort the `_import_structure` imports in `file`; `check_only` determines whether we only check or overwrite.
"""
with open(file, "r") as f:
code = f.read()
if "_import_structure" not in code:
return
# Blocks of indent level 0
main_blocks = split_code_in_indented_blocks(
code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
)
# We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
for block_idx in range(1, len(main_blocks) - 1):
# Check if the block contains some `_import_structure`s thingy to sort.
block = main_blocks[block_idx]
block_lines = block.split("\n")
# Get to the start of the imports.
line_idx = 0
while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
# Skip dummy import blocks
if "import dummy" in block_lines[line_idx]:
line_idx = len(block_lines)
else:
line_idx += 1
if line_idx >= len(block_lines):
continue
# Ignore beginning and last line: they don't contain anything.
internal_block_code = "\n".join(block_lines[line_idx:-1])
indent = get_indent(block_lines[1])
# Split the internal block into blocks of indent level 1.
internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
# We have two categories of import key: list or _import_structure[key].append/extend
pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
# Grab the keys, but there is a trap: some lines are empty or just comments.
keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
# We only sort the lines with a key.
keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
# We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
count = 0
reorderded_blocks = []
for i in range(len(internal_blocks)):
if keys[i] is None:
reorderded_blocks.append(internal_blocks[i])
else:
block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
reorderded_blocks.append(block)
count += 1
# And we put our main block back together with its first and last line.
main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])
if code != "\n".join(main_blocks):
if check_only:
return True
else:
print(f"Overwriting {file}.")
with open(file, "w") as f:
f.write("\n".join(main_blocks))
def sort_imports_in_all_inits(check_only=True):
failures = []
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
if "__init__.py" in files:
result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
if result:
failures.append(os.path.join(root, "__init__.py"))
if len(failures) > 0:
raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
args = parser.parse_args()
sort_imports_in_all_inits(check_only=args.check_only)