diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 7fb524110e..1234dbd2d5 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -879,6 +879,9 @@ class AttnAddedKVProcessor:
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -891,17 +894,17 @@ class AttnAddedKVProcessor:
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ class AttnAddedKVProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -946,6 +949,9 @@ class AttnAddedKVProcessor2_0:
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -958,7 +964,7 @@ class AttnAddedKVProcessor2_0:
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query, out_dim=4)
 
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ class AttnAddedKVProcessor2_0:
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -1177,6 +1183,8 @@ class AttnProcessor2_0:
         scale: float = 1.0,
     ) -> torch.FloatTensor:
         residual = hidden_states
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -1207,12 +1215,8 @@ class AttnProcessor2_0:
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = (
-            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
-        )
-        value = (
-            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
-        )
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = (
-            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
-        )
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
         hidden_states = attn.to_out[1](hidden_states)
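
Note: the args tuple introduced above keeps each projection call on a single line. With the PEFT backend active the projections are plain torch.nn.Linear modules and must not receive a scale argument; on the legacy LoRA path the layers accept an extra positional scale. The standalone sketch below (not part of the patch) illustrates that dispatch; ToyLoRALinear and use_peft_backend are hypothetical stand-ins for diffusers' LoRA-compatible linear layer and the USE_PEFT_BACKEND flag.

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Linear):
    """Hypothetical stand-in for a LoRA-aware linear whose forward takes an extra scale."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # A real LoRA layer would add scale * lora_delta(hidden_states) here.
        return super().forward(hidden_states)


def project(proj: nn.Module, hidden_states: torch.Tensor, use_peft_backend: bool, scale: float = 1.0) -> torch.Tensor:
    # Build the extra positional args once, as the patch does: the PEFT path
    # (plain nn.Linear) gets no scale, the legacy path receives it positionally.
    args = () if use_peft_backend else (scale,)
    return proj(hidden_states, *args)


x = torch.randn(2, 8, 16)
print(project(nn.Linear(16, 16), x, use_peft_backend=True).shape)                 # torch.Size([2, 8, 16])
print(project(ToyLoRALinear(16, 16), x, use_peft_backend=False, scale=0.5).shape)  # torch.Size([2, 8, 16])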