Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
NPU Adaptation for FLUX (#9751)
* NPU implementation for FLUX (×14 squashed commits)

---------

Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
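The change is exercised whenever diffusers runs on an Ascend NPU. A minimal inference sketch (not part of this diff; the model id, prompt, and device string are illustrative assumptions):

import torch
from diffusers import FluxPipeline
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  # registers the "npu" device type with PyTorch

# Assumed checkpoint id; any FLUX checkpoint works the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("npu" if is_torch_npu_available() else "cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("flux_npu.png")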
@@ -57,6 +57,7 @@ from diffusers.utils import (
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -68,6 +69,12 @@ check_min_version("0.32.0.dev0")
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
+
 
 def save_model_card(
     repo_id: str,
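The hunk above guards the torch_npu import and configures the NPU before training. A hedged sketch of the matching device/dtype selection a script might perform with the same availability checks (the helper name is an assumption, not code from this PR):

import torch
from diffusers.utils.import_utils import is_torch_npu_available


def pick_device_and_dtype():
    # Mirrors the "supported fp16 accelerator" logic used later in this diff.
    if torch.cuda.is_available():
        return "cuda", torch.float16
    if is_torch_npu_available():
        return "npu", torch.float16
    if torch.backends.mps.is_available():
        return "mps", torch.float16
    return "cpu", torch.float32


device, dtype = pick_device_and_dtype()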
@@ -189,6 +196,8 @@ def log_validation(
     del pipeline
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.npu.empty_cache()
 
     return images
 
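The CUDA/NPU cache-clearing branch above recurs in several hunks below; a small helper like the following captures the pattern. This is a sketch with an assumed name (flush_accelerator_cache), not an API added by this PR:

import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu


def flush_accelerator_cache():
    # Release cached allocator blocks on whichever accelerator is present.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()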
@@ -1035,7 +1044,9 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            has_supported_fp16_accelerator = (
+                torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
+            )
             torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
@@ -1073,6 +1084,8 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif is_torch_npu_available():
+                torch_npu.npu.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1354,6 +1367,8 @@ def main(args):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_torch_npu_available():
+            torch_npu.npu.empty_cache()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1719,7 +1734,10 @@ def main(args):
         )
         if not args.train_text_encoder:
             del text_encoder_one, text_encoder_two
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif is_torch_npu_available():
+                torch_npu.npu.empty_cache()
             gc.collect()
 
     # Save the lora layers
 
@@ -1893,6 +1893,112 @@ class FluxAttnProcessor2_0:
             return hidden_states
 
 
+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
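With this change, FluxAttnProcessor2_0_NPU is selected automatically in FluxSingleTransformerBlock when torch_npu is available (see the transformer_flux.py hunk at the end of this diff). Setting it explicitly, as sketched here, is only an illustration; the checkpoint id is an assumption:

import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Route every attention module through the NPU fusion-attention kernel.
transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())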
@@ -1987,6 +2093,117 @@ class FusedFluxAttnProcessor2_0:
             return hidden_states
 
 
+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
 
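fuse_qkv_projections() on the Flux transformer typically installs the stock FusedFluxAttnProcessor2_0, so on an NPU the fused variant above would be swapped in afterwards. A hedged sketch, not a documented recipe from this PR (checkpoint id assumed):

import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FusedFluxAttnProcessor2_0_NPU

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()  # merges to_q/to_k/to_v into to_qkv
transformer.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())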
@@ -27,11 +27,13 @@ from ...models.attention_processor import (
     Attention,
     AttentionProcessor,
     FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +66,10 @@ class FluxSingleTransformerBlock(nn.Module):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        processor = FluxAttnProcessor2_0()
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
 
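A quick way to confirm which processor the block constructor picked on a given machine; the tiny config below is an arbitrary assumption used only so the model is cheap to build:

from diffusers import FluxTransformer2DModel

# Expect FluxAttnProcessor2_0_NPU when torch_npu is installed, FluxAttnProcessor2_0 otherwise.
# (Config is for inspecting processors only, not a usable FLUX configuration.)
tiny = FluxTransformer2DModel(num_layers=1, num_single_layers=1, attention_head_dim=8, num_attention_heads=2)
print({name: type(proc).__name__ for name, proc in tiny.attn_processors.items()})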