commit f024e00398 (parent 2120b4eee3)
Author: Alexander Pivovarov
Date: 2023-03-21 12:45:04 +00:00
Committed by: GitHub
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

12 changed files with 44 additions and 44 deletions


@@ -69,7 +69,7 @@ class AttentionBlock(nn.Module):
         self.value = nn.Linear(channels, channels)
 
         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = nn.Linear(channels, channels, 1)
+        self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None
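
The `proj_attn` change above is cosmetic rather than behavioral: the third positional argument of `nn.Linear` is `bias`, so the literal `1` was already being read as a truthy `bias=True`. A minimal sketch to confirm the equivalence:

    import torch.nn as nn

    # nn.Linear(in_features, out_features, bias=True): the third positional
    # argument is `bias`, so passing the integer 1 means the same as bias=True.
    a = nn.Linear(4, 4, 1)
    b = nn.Linear(4, 4, bias=True)
    assert a.bias is not None and b.bias is not None  # both layers carry a bias term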


@@ -344,7 +344,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
                 of **all** `Attention` layers.
-            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
+            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
 
         """
         count = len(self.attn_processors.keys())
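
To make the dict form concrete, here is a hedged usage sketch (the checkpoint name is illustrative; `AttnProcessor` and the `attn_processors` property are the diffusers APIs this docstring refers to):

    from diffusers import ControlNetModel
    from diffusers.models.attention_processor import AttnProcessor

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

    # One processor instance applied to all Attention layers at once
    controlnet.set_attn_processor(AttnProcessor())

    # Or a dict keyed by the module path of each processor; the keys must
    # cover every entry reported by controlnet.attn_processors
    procs = {name: AttnProcessor() for name in controlnet.attn_processors.keys()}
    controlnet.set_attn_processor(procs)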
@@ -379,24 +379,24 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []
 
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)
 
             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)
 
         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)
 
-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)
 
         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
@@ -404,9 +404,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
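
The rename only touches the helper that collects head dimensions; the slicing logic itself maps `slice_size` onto one entry per sliceable layer. A standalone sketch of that mapping, using hypothetical head dims:

    # Hypothetical per-layer head dims, as the recursive walk would collect them
    sliceable_head_dims = [8, 16, 8]
    num_sliceable_layers = len(sliceable_head_dims)

    def resolve_slice_size(slice_size):
        if slice_size == "auto":
            # half of each head dim: attention runs in two steps per layer
            return [dim // 2 for dim in sliceable_head_dims]
        if slice_size == "max":
            # smallest possible slices: maximum memory savings, slowest
            return num_sliceable_layers * [1]
        # a single int fans out to every layer; a list is used as-is
        return slice_size if isinstance(slice_size, list) else num_sliceable_layers * [slice_size]

    print(resolve_slice_size("auto"))  # [4, 8, 4]
    print(resolve_slice_size("max"))   # [1, 1, 1]
    print(resolve_slice_size(4))       # [4, 4, 4]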


@@ -575,7 +575,7 @@ class ModelMixin(torch.nn.Module):
                 raise ValueError(
                     f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                     f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
-                    " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
+                    " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                     " those weights or else make sure your checkpoint file is correct."
                 )
@@ -591,7 +591,7 @@ class ModelMixin(torch.nn.Module):
                 set_module_tensor_to_device(model, param_name, param_device, value=param)
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
-            # by deafult the device_map is None and the weights are loaded on the CPU
+            # by default the device_map is None and the weights are loaded on the CPU
             accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
 
         loading_info = {
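
The corrected message names the two loading knobs involved. A hedged sketch of how they are typically passed to `from_pretrained` (model ID and subfolder are illustrative):

    from diffusers import UNet2DConditionModel

    # Default fast path: low_cpu_mem_usage=True, weights loaded and dispatched
    # via accelerate; device_map=None keeps everything on the CPU.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

    # Fallback suggested by the error: disable the accelerate path so that
    # missing keys can be randomly initialized instead of raising.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        low_cpu_mem_usage=False,
        device_map=None,
    )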


@@ -418,7 +418,7 @@ class ResnetBlock2D(nn.Module):
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
             By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
             "ada_group" for a stronger conditioning with scale and shift.
-        kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see
+        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):


@@ -105,7 +105,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
 
-        # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         self.is_input_vectorized = num_vector_embeds is not None
@@ -198,7 +198,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
         if self.is_input_continuous:
-            # TODO: should use out_channels for continous projections
+            # TODO: should use out_channels for continuous projections
             if use_linear_projection:
                 self.proj_out = nn.Linear(inner_dim, in_channels)
             else:
@@ -223,7 +223,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to


@@ -59,7 +59,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             obj:`(32, 32, 64)`): Tuple of block output channels.
         mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
         out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
-        act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
+        act_fn (`str`, *optional*, defaults to None): optional activation function in UNet blocks.
         norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
         layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
         downsample_each_block (`int`, *optional*, defaults to False:


@@ -331,7 +331,7 @@ class SelfAttention1d(nn.Module):
         self.key = nn.Linear(self.channels, self.channels)
         self.value = nn.Linear(self.channels, self.channels)
 
-        self.proj_attn = nn.Linear(self.channels, self.channels, 1)
+        self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
 
         self.dropout = nn.Dropout(dropout_rate, inplace=True)


@@ -2684,7 +2684,7 @@ class KAttentionBlock(nn.Module):
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            cross_attention_norm=None,
+            cross_attention_norm=False,
         )
 
         # 2. Cross-Attn
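
This hunk goes slightly beyond spelling: it swaps `None` for `False` on `cross_attention_norm`. Both values disable the normalization, but `False` appears to match the parameter's boolean typing in the `Attention` constructor of this era (hedged; the signature below is paraphrased, not quoted):

    # diffusers.models.attention_processor.Attention, circa this commit:
    # def __init__(self, ..., cross_attention_norm: bool = False, ...)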


@@ -197,7 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             timestep_input_dim = block_out_channels[0]
         else:
             raise ValueError(
-                f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
             )
 
         self.time_embedding = TimestepEmbedding(
@@ -391,7 +391,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
                 of **all** `Attention` layers.
-            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
+            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
 
         """
         count = len(self.attn_processors.keys())
@@ -425,24 +425,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []
 
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)
 
             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)
 
         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)
 
-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)
 
         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
@@ -450,9 +450,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -515,7 +515,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             returning a tuple, the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
-        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
         # However, the upsampling interpolation output size can be forced to fit any upsampling size
         # on the fly if necessary.
         default_overall_up_factor = 2**self.num_upsamplers
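
A quick worked example of the corrected comment: with `num_upsamplers` upsampling layers, input height and width must be divisible by 2**num_upsamplers unless a forced upsample size is supplied on the fly (variable names mirror the surrounding code but the values are illustrative):

    num_upsamplers = 3                              # e.g. three upsampling stages
    default_overall_up_factor = 2**num_upsamplers   # 8

    sample_height = 64
    # 64 % 8 == 0, so this input needs no forced interpolation size
    forward_upsample_size = sample_height % default_overall_up_factor != 0
    print(default_overall_up_factor, forward_upsample_size)  # 8 False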


@@ -1351,7 +1351,7 @@ class DiffusionPipeline(ConfigMixin):
         Args:
             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """


@@ -287,7 +287,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             timestep_input_dim = block_out_channels[0]
         else:
             raise ValueError(
-                f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
             )
 
         self.time_embedding = TimestepEmbedding(
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
                 of **all** `Attention` layers.
-            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
+            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
 
         """
         count = len(self.attn_processors.keys())
@@ -515,24 +515,24 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []
 
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)
 
             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)
 
         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)
 
-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)
 
         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
@@ -540,9 +540,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -605,7 +605,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
-        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
         # However, the upsampling interpolation output size can be forced to fit any upsampling size
         # on the fly if necessary.
         default_overall_up_factor = 2**self.num_upsamplers


@@ -223,23 +223,23 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         output = model(**inputs_dict)
 
         assert output is not None
 
-    def test_model_slicable_head_dim(self):
+    def test_model_sliceable_head_dim(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
 
-        def check_slicable_dim_attr(module: torch.nn.Module):
+        def check_sliceable_dim_attr(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 assert isinstance(module.sliceable_head_dim, int)
 
             for child in module.children():
-                check_slicable_dim_attr(child)
+                check_sliceable_dim_attr(child)
 
         # retrieve number of attention layers
         for module in model.children():
-            check_slicable_dim_attr(module)
+            check_sliceable_dim_attr(module)
 
     def test_special_attn_proc(self):
         class AttnEasyProc(torch.nn.Module):
@@ -658,7 +658,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
         torch.cuda.reset_max_memory_allocated()
         torch.cuda.reset_peak_memory_stats()
 
-        # there are 32 slicable layers
+        # there are 32 sliceable layers
         slice_list = 16 * [2, 3]
         unet = self.get_unet_model()
         unet.set_attention_slice(slice_list)