mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Pass LoRA rank to LoRALinearLayer (#2191)
This commit is contained in:
@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
|
||||
super().__init__()
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
||||
|
||||
def __call__(
|
||||
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
|
||||
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
|
||||
def __init__(self, hidden_size, cross_attention_dim, rank=4):
|
||||
super().__init__()
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
||||
|
||||
def __call__(
|
||||
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
|
||||
|
||||
Reference in New Issue
Block a user