Mirror of https://github.com/huggingface/diffusers.git
Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. (#7816)
* Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed.
* Fix check code quality.
* Decouple the NPU flash attention and make it an independent module.
* Add doc and unit tests for NPU flash attention.

Co-authored-by: mhh001 <mahonghao1@huawei.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Committed by: GitHub
Parent: 3e35628873
Commit: 58237364b1
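Taken together, the changes add one CLI flag to the SDXL training scripts and one model-level switch on `ModelMixin`. The sketch below is a compressed, hypothetical illustration of that wiring (the helper names are ours, not the scripts'); the flag name, the availability check, and the error message come from the diff itself.

```python
import argparse

from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available


def add_npu_flag(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Same flag the commit adds to the SDXL training scripts.
    parser.add_argument(
        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
    )
    return parser


def maybe_enable_npu_flash_attention(unet: UNet2DConditionModel, enable: bool) -> None:
    # Same gate the scripts use: only valid when torch_npu is importable on an NPU host.
    if not enable:
        return
    if not is_torch_npu_available():
        raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
    unet.enable_npu_flash_attention()  # model-level switch added to ModelMixin in this commit
```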
@@ -55,3 +55,6 @@ An attention processor is a class for applying different types of attention mechanisms.
 ## XFormersAttnProcessor

 [[autodoc]] models.attention_processor.XFormersAttnProcessor

+## AttnProcessorNPU
+
+[[autodoc]] models.attention_processor.AttnProcessorNPU
@@ -32,7 +32,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -53,7 +53,7 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -64,6 +64,8 @@ if is_wandb_available():
 check_min_version("0.28.0.dev0")

 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False


 def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
@@ -471,6 +473,9 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -936,6 +941,13 @@ def main(args):
     text_encoder_two.requires_grad_(False)
     controlnet.train()

+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -1235,7 +1247,8 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
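The changed save condition is the DeepSpeed fix from the commit title: DeepSpeed shards model and optimizer state across ranks, so `save_state` has to run on every process, not only the main one. A minimal sketch of the same pattern as a standalone helper (the function name and signature are hypothetical; `Accelerator.save_state` and `DistributedType` are real `accelerate` APIs):

```python
import os

from accelerate import Accelerator
from accelerate.utils import DistributedType


def maybe_save_checkpoint(accelerator: Accelerator, output_dir: str, global_step: int, checkpointing_steps: int) -> None:
    # Under DeepSpeed every rank holds a shard of the state, so save_state() has to be
    # called on all ranks; gating on is_main_process alone would drop the other shards.
    should_save = accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process
    if should_save and global_step % checkpointing_steps == 0:
        accelerator.save_state(os.path.join(output_dir, f"checkpoint-{global_step}"))
```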
@@ -32,7 +32,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -60,7 +60,7 @@ from diffusers.utils import (
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -68,6 +68,8 @@ from diffusers.utils.torch_utils import is_compiled_module
 check_min_version("0.28.0.dev0")

 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False


 def save_model_card(
@@ -419,6 +421,9 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
     parser.add_argument(
         "--rank",
@@ -623,6 +628,13 @@ def main(args):
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)

+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -1149,7 +1161,8 @@ def main(args):
                 accelerator.log({"train_loss": train_loss}, step=global_step)
                 train_loss = 0.0

-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -18,8 +18,12 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import deprecate
+from ..utils.import_utils import is_torch_npu_available
+
+
+if is_torch_npu_available():
+    import torch_npu

 ACTIVATION_FUNCTIONS = {
     "swish": nn.SiLU(),
     "silu": nn.SiLU(),
@@ -98,9 +102,13 @@ class GEGLU(nn.Module):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)

-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * self.gelu(gate)
+        hidden_states = self.proj(hidden_states)
+        if is_torch_npu_available():
+            # using torch_npu.npu_geglu can run faster and save memory on NPU.
+            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+        else:
+            hidden_states, gate = hidden_states.chunk(2, dim=-1)
+            return hidden_states * self.gelu(gate)


 class ApproximateGELU(nn.Module):
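Both branches of the new `GEGLU.forward` compute the same gated activation; the NPU branch just fuses the split and GELU into a single `torch_npu.npu_geglu` kernel. A CPU-only sketch of the fallback path, with made-up shapes, for reference:

```python
import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 64, 128                # hypothetical feature sizes
proj = nn.Linear(dim_in, dim_out * 2)    # GEGLU projects to twice the output width

x = torch.randn(2, 10, dim_in)           # (batch, seq_len, dim_in)
hidden, gate = proj(x).chunk(2, dim=-1)  # split the projection into value and gate halves
out = hidden * F.gelu(gate)              # gated GELU; npu_geglu fuses this on Ascend NPUs
print(out.shape)                         # torch.Size([2, 10, 128])
```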
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import math
 from importlib import import_module
 from typing import Callable, List, Optional, Union
@@ -21,13 +22,15 @@ from torch import nn

 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
-from ..utils.import_utils import is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+if is_torch_npu_available():
+    import torch_npu
+
 if is_xformers_available():
     import xformers
@@ -209,6 +212,23 @@ class Attention(nn.Module):
             )
         self.set_processor(processor)

+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        r"""
+        Set whether to use npu flash attention from `torch_npu` or not.
+
+        """
+        if use_npu_flash_attention:
+            processor = AttnProcessorNPU()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ) -> None:
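With this method in place, flipping the flag on a single `Attention` module swaps its processor. A small sketch (the module sizes are arbitrary; constructing `AttnProcessorNPU` requires `torch_npu`, so the NPU branch only runs on an Ascend host):

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0
from diffusers.utils.import_utils import is_torch_npu_available

attn = Attention(query_dim=320, heads=8, dim_head=40)  # arbitrary example sizes

if is_torch_npu_available():
    attn.set_use_npu_flash_attention(True)   # installs AttnProcessorNPU
    attn.set_use_npu_flash_attention(False)  # falls back to the default processor below

# Off an NPU host, the default selection mirrors the else-branch above.
expected = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
assert isinstance(attn.processor, expected)
```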
@@ -1207,6 +1227,116 @@ class XFormersAttnProcessor:
         return hidden_states


+class AttnProcessorNPU:
+    r"""
+    Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+    fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+    not significant.
+    """
+
+    def __init__(self):
+        if not is_torch_npu_available():
+            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                atten_mask=attention_mask,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
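A usage sketch for the new processor (arbitrary sizes, NPU-only): in fp16/bf16 the call is routed to `torch_npu.npu_fusion_attention`, while fp32 inputs fall back to `F.scaled_dot_product_attention`, as the docstring above notes.

```python
import torch

from diffusers.models.attention_processor import Attention, AttnProcessorNPU
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():  # AttnProcessorNPU.__init__ raises ImportError otherwise
    attn = Attention(query_dim=320, heads=8, dim_head=40).to("npu", dtype=torch.float16)
    attn.set_processor(AttnProcessorNPU())

    hidden_states = torch.randn(1, 64, 320, device="npu", dtype=torch.float16)
    out = attn(hidden_states)  # fp16 path -> torch_npu.npu_fusion_attention
    print(out.shape)           # torch.Size([1, 64, 320])
```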
@@ -272,6 +272,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))

+    def set_use_npu_flash_attention(self, valid: bool) -> None:
+        r"""
+        Set the switch for the npu flash attention.
+        """
+
+        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_npu_flash_attention"):
+                module.set_use_npu_flash_attention(valid)
+
+            for child in module.children():
+                fn_recursive_set_npu_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_npu_flash_attention(module)
+
+    def enable_npu_flash_attention(self) -> None:
+        r"""
+        Enable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(True)
+
+    def disable_npu_flash_attention(self) -> None:
+        r"""
+        Disable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
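Because `set_use_npu_flash_attention` recurses over child modules, the toggle can be flipped once at the model level and every `Attention` block follows. A sketch against a tiny, hypothetical `UNet2DConditionModel` config (again NPU-only, since the processor cannot be constructed without `torch_npu`):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessorNPU
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    # Tiny, made-up config just to keep the example cheap; any ModelMixin subclass
    # that owns Attention blocks behaves the same way.
    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
    )

    unet.enable_npu_flash_attention()   # recursively calls set_use_npu_flash_attention(True)
    assert all(isinstance(p, AttnProcessorNPU) for p in unet.attn_processors.values())

    unet.disable_npu_flash_attention()  # restores the default processors
```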
@@ -30,9 +30,14 @@ from huggingface_hub.utils import is_jinja_available
 from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor,
+    AttnProcessor2_0,
+    AttnProcessorNPU,
+    XFormersAttnProcessor,
+)
 from diffusers.training_utils import EMAModel
-from diffusers.utils import is_xformers_available, logging
+from diffusers.utils import is_torch_npu_available, is_xformers_available, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -300,6 +305,53 @@ class ModelTesterMixin:

         assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

+    @unittest.skipIf(
+        torch_device != "npu" or not is_torch_npu_available(),
+        reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
+    )
+    def test_set_torch_npu_flash_attn_processor_determinism(self):
+        torch.use_deterministic_algorithms(False)
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        if not hasattr(model, "set_attn_processor"):
+            # If not has `set_attn_processor`, skip test
+            return
+
+        model.set_default_attn_processor()
+        assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output = model(**self.inputs_dict(0))[0]
+            else:
+                output = model(**inputs_dict)[0]
+
+        model.enable_npu_flash_attention()
+        assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output_2 = model(**self.inputs_dict(0))[0]
+            else:
+                output_2 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(AttnProcessorNPU())
+        assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output_3 = model(**self.inputs_dict(0))[0]
+            else:
+                output_3 = model(**inputs_dict)[0]
+
+        torch.use_deterministic_algorithms(True)
+
+        assert torch.allclose(output, output_2, atol=self.base_precision)
+        assert torch.allclose(output, output_3, atol=self.base_precision)
+        assert torch.allclose(output_2, output_3, atol=self.base_precision)
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",