Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Bugfix] fix error of peft lora when xformers enabled (#5697)
* bugfix peft lora

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
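For context on the failure mode being fixed: when the PEFT backend is active, the attention projections (`attn.to_q`, `attn.add_k_proj`, and so on) are plain `torch.nn.Linear` / PEFT-managed modules rather than diffusers' LoRA-compatible layers, so the `scale=scale` keyword these processors were still passing raises a `TypeError`. A minimal, self-contained sketch of that behavior (the variable names are illustrative stand-ins, not taken from the diffusers codebase):

```python
# Minimal sketch (hypothetical names) of the failure this commit addresses.
# Under the PEFT backend, attn.to_q and friends behave like plain
# torch.nn.Linear modules, so passing a LoRA `scale` keyword to them fails.
import torch

to_q = torch.nn.Linear(8, 8)          # stand-in for a PEFT-managed projection
hidden_states = torch.randn(2, 4, 8)

try:
    to_q(hidden_states, scale=1.0)    # old call style used by these processors
except TypeError as err:
    print("breaks under the PEFT backend:", err)

print(to_q(hidden_states).shape)      # the fixed call path forwards no extra args
```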
@@ -879,6 +879,9 @@ class AttnAddedKVProcessor:
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -891,17 +894,17 @@ class AttnAddedKVProcessor:
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ class AttnAddedKVProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -946,6 +949,9 @@ class AttnAddedKVProcessor2_0:
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -958,7 +964,7 @@ class AttnAddedKVProcessor2_0:
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query, out_dim=4)
 
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ class AttnAddedKVProcessor2_0:
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -1177,6 +1183,8 @@ class AttnProcessor2_0:
     ) -> torch.FloatTensor:
         residual = hidden_states
 
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -1207,12 +1215,8 @@ class AttnProcessor2_0:
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = (
-            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
-        )
-        value = (
-            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
-        )
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = (
-            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
-        )
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
         hidden_states = attn.to_out[1](hidden_states)
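The repeated change above is one dispatch pattern: build `args = () if USE_PEFT_BACKEND else (scale,)` once, then splat it into every projection call so the LoRA scale is only forwarded to layers that accept it. Below is a self-contained sketch of that pattern; `LoRACompatibleLinearStub` and `project` are made-up names standing in for diffusers' LoRA-compatible layers and the processor call sites, not the library implementation.

```python
# Sketch of the `args = () if USE_PEFT_BACKEND else (scale,)` dispatch pattern.
# The stub layer mimics how a LoRA-compatible layer accepts an extra positional
# `scale`, while a plain nn.Linear (the PEFT-backend case) does not.
import torch
import torch.nn as nn


class LoRACompatibleLinearStub(nn.Linear):
    """Toy layer that, like diffusers' LoRA-compatible layers, takes a scale."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # A real layer would add `scale * lora(hidden_states)` here.
        return super().forward(hidden_states)


def project(layer: nn.Module, hidden_states: torch.Tensor, scale: float, use_peft_backend: bool) -> torch.Tensor:
    # The pattern from the diff: only forward `scale` when the layer can take it.
    args = () if use_peft_backend else (scale,)
    return layer(hidden_states, *args)


x = torch.randn(2, 4, 8)
print(project(LoRACompatibleLinearStub(8, 8), x, scale=0.5, use_peft_backend=False).shape)
print(project(nn.Linear(8, 8), x, scale=0.5, use_peft_backend=True).shape)
```

The same call site now works on both code paths, which is why the diff can drop the earlier per-call `if not USE_PEFT_BACKEND` ternaries in `AttnProcessor2_0`.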