Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Improve typehints and docs in diffusers/models (#5312)
* improvement: add missing typehints and docs to diffusers/models/attention.py
* chore: convert doc strings to raw python strings, add missing typehints
* improvement: add missing typehints and docs to diffusers/models/adapter.py
* improvement: add missing typehints and docs to diffusers/models/lora.py
* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/models/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
src/diffusers/models/adapter.py
@@ -231,7 +231,11 @@ class T2IAdapter(ModelMixin, ConfigMixin):
The number of channels of each downsample block's output hidden state. The `len(block_out_channels)` will
also determine the number of downsample blocks in the Adapter.
num_res_blocks (`int`, *optional*, defaults to 2):
Number of ResNet blocks in each downsample block
Number of ResNet blocks in each downsample block.
downscale_factor (`int`, *optional*, defaults to 8):
A factor that determines the total downscale factor of the Adapter.
adapter_type (`str`, *optional*, defaults to `full_adapter`):
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
"""

@register_to_config
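For context, a minimal usage sketch of the `T2IAdapter` documented above, using the defaults listed in this docstring; the input size is only an example and should stay divisible by the adapter's total downscale factor.

```python
import torch
from diffusers import T2IAdapter

# Defaults documented above: num_res_blocks=2, downscale_factor=8, adapter_type="full_adapter".
adapter = T2IAdapter()

# Dummy conditioning image; H and W should be divisible by adapter.total_downscale_factor.
image = torch.randn(1, 3, 512, 512)
features = adapter(image)  # one feature map per downsample block
for feature in features:
    print(tuple(feature.shape))
```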
@@ -275,6 +279,10 @@ class T2IAdapter(ModelMixin, ConfigMixin):

class FullAdapter(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""

def __init__(
self,
in_channels: int = 3,
@@ -321,6 +329,10 @@ class FullAdapter(nn.Module):

class FullAdapterXL(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""

def __init__(
self,
in_channels: int = 3,
@@ -367,7 +379,22 @@ class FullAdapterXL(nn.Module):

class AdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
r"""
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
`FullAdapterXL` models.

Parameters:
in_channels (`int`):
Number of channels of AdapterBlock's input.
out_channels (`int`):
Number of channels of AdapterBlock's output.
num_res_blocks (`int`):
Number of ResNet blocks in the AdapterBlock.
down (`bool`, *optional*, defaults to `False`):
Whether to perform downsampling on AdapterBlock's input.
"""

def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
super().__init__()

self.downsample = None
@@ -382,7 +409,7 @@ class AdapterBlock(nn.Module):
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes tensor x as input and applies downsampling and an input convolution when the
`self.downsample` and `self.in_conv` attributes of the `AdapterBlock` model are set. Then it applies a series of
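The truncated forward docstring above describes the data flow; the following sketch spells it out with stand-in modules. The AvgPool2d and 1x1-conv choices here are assumptions for illustration, not necessarily what the library uses.

```python
import torch
import torch.nn as nn

# Stand-ins for the attributes named in the hunk: `downsample`, `in_conv`, and the `resnets` stack.
downsample = nn.AvgPool2d(kernel_size=2, stride=2)   # assumed downsampling op, used when down=True
in_conv = nn.Conv2d(320, 640, kernel_size=1)         # assumed channel projection, in_channels -> out_channels
resnets = nn.Sequential(nn.Identity())               # placeholder for num_res_blocks AdapterResnetBlocks

x = torch.randn(1, 320, 64, 64)
x = downsample(x)      # (1, 320, 32, 32)
x = in_conv(x)         # (1, 640, 32, 32)
x = resnets(x)         # residual blocks keep channels and resolution unchanged
print(tuple(x.shape))  # (1, 640, 32, 32)
```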
@@ -400,13 +427,21 @@ class AdapterBlock(nn.Module):

class AdapterResnetBlock(nn.Module):
def __init__(self, channels):
r"""
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.

Parameters:
channels (`int`):
Number of channels of AdapterResnetBlock's input and output.
"""

def __init__(self, channels: int):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
layer on the input tensor. It returns the result added to the input tensor.
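Since all of the block's layers are visible in this hunk, the residual pattern can be restated as a self-contained sketch:

```python
import torch
import torch.nn as nn


class AdapterResnetBlockSketch(nn.Module):
    # 3x3 conv -> ReLU -> 1x1 conv, with the result added back onto the input, as the docstring describes.
    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.block2(self.act(self.block1(x)))
        return h + x


x = torch.randn(1, 320, 32, 32)
print(tuple(AdapterResnetBlockSketch(320)(x).shape))  # (1, 320, 32, 32); shape is preserved
```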
@@ -423,6 +458,10 @@ class AdapterResnetBlock(nn.Module):

class LightAdapter(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""

def __init__(
self,
in_channels: int = 3,
@@ -449,7 +488,7 @@ class LightAdapter(nn.Module):

self.total_downscale_factor = downscale_factor * (2 ** len(channels))

def forward(self, x):
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method takes the input tensor x, performs downscaling, and appends the result to a list of feature tensors. Each
feature tensor corresponds to a different level of processing within the LightAdapter.
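The `total_downscale_factor` line above is the key sizing constraint; a quick arithmetic sketch (the channel list and factor are placeholders, only the formula comes from the hunk):

```python
# Formula from the hunk: total = downscale_factor * 2 ** len(channels)
downscale_factor = 8          # placeholder value
channels = [320, 640, 1280]   # placeholder channel list

total_downscale_factor = downscale_factor * (2 ** len(channels))
print(total_downscale_factor)  # 64

# Conditioning images should have height/width divisible by this factor so every
# downscaling stage tiles the input evenly.
height, width = 512, 512
assert height % total_downscale_factor == 0 and width % total_downscale_factor == 0
```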
@@ -466,7 +505,22 @@ class LightAdapter(nn.Module):

class LightAdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
r"""
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
`LightAdapter` model.

Parameters:
in_channels (`int`):
Number of channels of LightAdapterBlock's input.
out_channels (`int`):
Number of channels of LightAdapterBlock's output.
num_res_blocks (`int`):
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
down (`bool`, *optional*, defaults to `False`):
Whether to perform downsampling on LightAdapterBlock's input.
"""

def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
super().__init__()
mid_channels = out_channels // 4

@@ -478,7 +532,7 @@ class LightAdapterBlock(nn.Module):
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes tensor x as input and performs downsampling if required. Then it applies the input convolution
layer, a sequence of residual blocks, and the output convolution layer.
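The hunk shows the bottleneck layout (`mid_channels = out_channels // 4`, residual blocks at `mid_channels`, a 1x1 `out_conv`). A sketch with stand-in modules follows; the `in_conv` shape below is an assumption, since the hunk does not show it.

```python
import torch
import torch.nn as nn

in_channels, out_channels, num_res_blocks = 320, 640, 4
mid_channels = out_channels // 4  # 160, as in the hunk

in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)                 # assumed shape
resnets = nn.Sequential(*[nn.Identity() for _ in range(num_res_blocks)])      # placeholder residual blocks
out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)               # as shown in the hunk

x = torch.randn(1, in_channels, 32, 32)
y = out_conv(resnets(in_conv(x)))
print(tuple(y.shape))  # (1, 640, 32, 32)
```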
@@ -494,13 +548,22 @@ class LightAdapterBlock(nn.Module):

class LightAdapterResnetBlock(nn.Module):
def __init__(self, channels):
"""
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
architecture than `AdapterResnetBlock`.

Parameters:
channels (`int`):
Number of channels of LightAdapterResnetBlock's input and output.
"""

def __init__(self, channels: int):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
another convolutional layer, then adds the result to the input tensor.
src/diffusers/models/attention.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -26,7 +26,17 @@ from .lora import LoRACompatibleLinear

@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
r"""
A gated self-attention dense layer that combines visual features and object features.

Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""

def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()

# we need a linear projection since we need cat visual feature and obj feature
@@ -43,7 +53,7 @@ class GatedSelfAttentionDense(nn.Module):

self.enabled = True

def forward(self, x, objs):
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x
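To illustrate the idea behind the docstring (visual tokens fused with projected object tokens through a gate that starts closed), here is a conceptual sketch. It is not the library's implementation; the layer and parameter choices below are assumptions.

```python
import torch
import torch.nn as nn


class GatedFusionSketch(nn.Module):
    # Gated residual attention: project object features to the visual width, attend over the
    # concatenation, and blend the result back through a tanh-squashed gate initialized at zero.
    def __init__(self, query_dim: int, context_dim: int, n_heads: int):
        super().__init__()
        self.linear = nn.Linear(context_dim, query_dim)
        self.attn = nn.MultiheadAttention(query_dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(query_dim)
        self.alpha = nn.Parameter(torch.tensor(0.0))  # gate starts closed: tanh(0) == 0

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        n_visual = x.shape[1]
        objs = self.linear(objs)
        h = self.norm(torch.cat([x, objs], dim=1))
        attn_out, _ = self.attn(h, h, h)
        return x + torch.tanh(self.alpha) * attn_out[:, :n_visual]


x = torch.randn(2, 64, 320)      # visual tokens
objs = torch.randn(2, 8, 768)    # grounding/object tokens
print(GatedFusionSketch(320, 768, n_heads=8)(x, objs).shape)  # torch.Size([2, 64, 320])
```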
@@ -67,15 +77,25 @@ class BasicTransformerBlock(nn.Module):
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
"""

def __init__(
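The parameters above map directly onto the constructor; a minimal usage sketch with arbitrary sizes:

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# dim = num_attention_heads * attention_head_dim is the conventional choice.
block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,
    activation_fn="geglu",
)

hidden_states = torch.randn(2, 64, 320)            # (batch, tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)    # e.g. text-encoder output for cross-attention
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```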
@@ -175,7 +195,7 @@ class BasicTransformerBlock(nn.Module):
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
):
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
if self.use_ada_layer_norm:
@@ -301,7 +321,7 @@ class FeedForward(nn.Module):
if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states, scale: float = 1.0):
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
for module in self.net:
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
hidden_states = module(hidden_states, scale)
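A short usage sketch of `FeedForward`, whose `scale` argument (shown above) is only consumed by LoRA-compatible sublayers:

```python
import torch
from diffusers.models.attention import FeedForward

# GEGLU feed-forward: dim -> dim * mult (gated) -> dim. With no LoRA layers attached,
# the `scale` argument has no effect on the output.
ff = FeedForward(dim=320, mult=4, activation_fn="geglu")

hidden_states = torch.randn(2, 64, 320)
print(ff(hidden_states, scale=1.0).shape)  # torch.Size([2, 64, 320])
```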
@@ -313,6 +333,11 @@ class FeedForward(nn.Module):
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.

Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
"""

def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
@@ -320,7 +345,7 @@ class GELU(nn.Module):
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate

def gelu(self, gate):
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
@@ -345,7 +370,7 @@ class GEGLU(nn.Module):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)

def gelu(self, gate):
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
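The `dim_out * 2` projection above is the GEGLU trick: half of the projection acts as the value and half as a GELU gate. A standalone sketch of that computation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Split the doubled projection into a value half and a gate half, then gate with GELU.
dim_in, dim_out = 320, 1280
proj = nn.Linear(dim_in, dim_out * 2)

hidden_states = torch.randn(2, 64, dim_in)
value, gate = proj(hidden_states).chunk(2, dim=-1)
out = value * F.gelu(gate)
print(tuple(out.shape))  # (2, 64, 1280)
```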
@@ -357,34 +382,41 @@ class GEGLU(nn.Module):

class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
r"""
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
https://arxiv.org/abs/1606.08415.

For more details, see section 2: https://arxiv.org/abs/1606.08415
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
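A quick numerical check of the sigmoid form used in this forward pass against the exact GELU:

```python
import torch
import torch.nn.functional as F

# x * sigmoid(1.702 * x) tracks the exact GELU closely over a typical activation range.
x = torch.linspace(-4, 4, steps=9)
exact = F.gelu(x)
approx = x * torch.sigmoid(1.702 * x)
print((exact - approx).abs().max())  # about 0.02, small relative to the activation scale
```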
class AdaLayerNorm(nn.Module):
"""
r"""
Norm layer modified to incorporate timestep embeddings.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
"""

def __init__(self, embedding_dim, num_embeddings):
def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

def forward(self, x, timestep):
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
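All of `AdaLayerNorm`'s layers are visible above, so the modulation can be written out inline as a sketch:

```python
import torch
import torch.nn as nn

# A timestep embedding is turned into a per-feature (scale, shift) pair that modulates a
# parameter-free LayerNorm, exactly as in the forward shown above.
embedding_dim, num_embeddings = 320, 1000
emb = nn.Embedding(num_embeddings, embedding_dim)
silu = nn.SiLU()
linear = nn.Linear(embedding_dim, embedding_dim * 2)
norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

x = torch.randn(64, embedding_dim)   # token features
timestep = torch.tensor(10)          # a single (scalar) diffusion timestep index
scale, shift = torch.chunk(linear(silu(emb(timestep))), 2)
out = norm(x) * (1 + scale) + shift  # broadcast over the token dimension
print(out.shape)  # torch.Size([64, 320])
```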
@@ -392,11 +424,15 @@ class AdaLayerNorm(nn.Module):

class AdaLayerNormZero(nn.Module):
"""
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
"""

def __init__(self, embedding_dim, num_embeddings):
def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()

self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
@@ -405,7 +441,13 @@ class AdaLayerNormZero(nn.Module):
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

def forward(self, x, timestep, class_labels, hidden_dtype=None):
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
class_labels: torch.LongTensor,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
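The adaLN-Zero split in the last lines above can be sketched with a random conditioning vector standing in for the timestep/class embedding:

```python
import torch
import torch.nn as nn

# One conditioning vector per sample is projected to 6 * embedding_dim and chunked into
# shift/scale/gate triples for the attention and MLP branches.
batch, tokens, embedding_dim = 2, 64, 320
linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

cond = torch.randn(batch, embedding_dim)  # stands in for CombinedTimestepLabelEmbeddings output
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = linear(cond).chunk(6, dim=1)

x = torch.randn(batch, tokens, embedding_dim)
x_mod = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]  # per-sample, broadcast over tokens
print(x_mod.shape, gate_msa.shape)  # torch.Size([2, 64, 320]) torch.Size([2, 320])
```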
@@ -413,8 +455,15 @@ class AdaLayerNormZero(nn.Module):

class AdaGroupNorm(nn.Module):
"""
r"""
GroupNorm layer modified to incorporate timestep embeddings.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
num_groups (`int`): The number of groups to separate the channels into.
act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
"""

def __init__(
@@ -431,7 +480,7 @@ class AdaGroupNorm(nn.Module):

self.linear = nn.Linear(embedding_dim, out_dim * 2)

def forward(self, x, emb):
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
if self.act:
emb = self.act(emb)
emb = self.linear(emb)
src/diffusers/models/lora.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -40,7 +40,35 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):

class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
r"""
A linear layer that is used with LoRA.

Parameters:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
rank (`int`, `optional`, defaults to 4):
The rank of the LoRA layer.
network_alpha (`float`, `optional`, defaults to `None`):
The value of the network alpha used for stable learning and preventing underflow. This value has the same
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
device (`torch.device`, `optional`, defaults to `None`):
The device to use for the layer's weights.
dtype (`torch.dtype`, `optional`, defaults to `None`):
The dtype to use for the layer's weights.
"""

def __init__(
self,
in_features: int,
out_features: int,
rank: int = 4,
network_alpha: Optional[float] = None,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()

self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
@@ -55,7 +83,7 @@ class LoRALinearLayer(nn.Module):
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
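A small usage sketch of `LoRALinearLayer`; the zero initialization of `up` (visible above) means a freshly created layer contributes nothing until it is trained:

```python
import torch
from diffusers.models.lora import LoRALinearLayer

# Rank-4 LoRA update for a 320 -> 320 linear layer. Because `up` is zero-initialized, the
# output is exactly zero at creation, so attaching it does not change the base model yet.
lora = LoRALinearLayer(in_features=320, out_features=320, rank=4)

hidden_states = torch.randn(2, 64, 320)
delta = lora(hidden_states)
print(delta.shape, delta.abs().max().item())  # torch.Size([2, 64, 320]) 0.0
```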
@@ -69,8 +97,37 @@ class LoRALinearLayer(nn.Module):

class LoRAConv2dLayer(nn.Module):
r"""
A convolutional layer that is used with LoRA.

Parameters:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
rank (`int`, `optional`, defaults to 4):
The rank of the LoRA layer.
kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
The kernel size of the convolution.
stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
The stride of the convolution.
padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
The padding of the convolution.
network_alpha (`float`, `optional`, defaults to `None`):
The value of the network alpha used for stable learning and preventing underflow. This value has the same
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
"""

def __init__(
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
self,
in_features: int,
out_features: int,
rank: int = 4,
kernel_size: Union[int, Tuple[int, int]] = (1, 1),
stride: Union[int, Tuple[int, int]] = (1, 1),
padding: Union[int, Tuple[int, int], str] = 0,
network_alpha: Optional[float] = None,
):
super().__init__()

@@ -87,7 +144,7 @@ class LoRAConv2dLayer(nn.Module):
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
@@ -112,7 +169,7 @@ class LoRACompatibleConv(nn.Conv2d):
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
if self.lora_layer is None:
return

@@ -164,7 +221,7 @@ class LoRACompatibleConv(nn.Conv2d):
self.w_up = None
self.w_down = None

def forward(self, hidden_states, scale: float = 1.0):
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
if self.lora_layer is None:
# make sure to use the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315
@@ -190,7 +247,7 @@ class LoRACompatibleLinear(nn.Linear):
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
if self.lora_layer is None:
return

@@ -238,7 +295,7 @@ class LoRACompatibleLinear(nn.Linear):
self.w_up = None
self.w_down = None

def forward(self, hidden_states, scale: float = 1.0):
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
if self.lora_layer is None:
out = super().forward(hidden_states)
return out
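Finally, a sketch of how the two pieces fit together: `LoRACompatibleLinear` behaves like a plain `nn.Linear` until a `LoRALinearLayer` is attached with `set_lora_layer`, and at initialization the attached layer's contribution is zero:

```python
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

linear = LoRACompatibleLinear(320, 320)
x = torch.randn(2, 64, 320)

base = linear(x)                                   # plain nn.Linear path (lora_layer is None)
linear.set_lora_layer(LoRALinearLayer(320, 320, rank=4))
with_lora = linear(x, scale=1.0)                   # LoRA delta is zero at init, so outputs match
print(torch.allclose(base, with_lora))             # True
```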