Update ptxla training (#9864)
* update ptxla example

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Pei Zhang <zpcore@gmail.com>
Co-authored-by: Pei Zhang <piz@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pei Zhang <pei@Peis-MacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
This script implements Distributed Data Parallel using the GSPMD feature in the XLA compiler,
where we shard the input batches over the TPU devices.

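In code, that data-parallel sharding boils down to a handful of torch_xla SPMD calls. The sketch below is illustrative only (the toy dataset, tensor shapes, and variable names are assumptions); the mesh and `input_sharding` calls mirror the ones used later in this change:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()                    # run the program in SPMD mode
mesh = xs.get_1d_mesh("data")    # one mesh axis named "data" spanning all devices
xs.set_global_mesh(mesh)

# Toy CPU-side dataloader standing in for the real dataset (shapes are made up).
samples = [
    {"pixel_values": torch.randn(3, 512, 512), "input_ids": torch.ones(77, dtype=torch.long)}
    for _ in range(16)
]
cpu_loader = torch.utils.data.DataLoader(samples, batch_size=8)

# Shard only the batch dimension of each input over the "data" axis; the other
# dimensions stay replicated on every device.
train_loader = pl.MpDeviceLoader(
    cpu_loader,
    xm.xla_device(),
    input_sharding={
        "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
        "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
    },
)
```
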
As of 9-11-2024, these are some expected step times.
As of 10-31-2024, these are some expected step times.

| accelerator | global batch size | step time (seconds) |
| ----------- | ----------------- | --------- |
| v5p-128 | 1024 | 0.245 |
| v5p-256 | 2048 | 0.234 |
| v5p-512 | 4096 | 0.2498 |
| v5p-512 | 16384 | 1.01 |
| v5p-256 | 8192 | 1.01 |
| v5p-128 | 4096 | 1.0 |
| v5p-64 | 2048 | 1.01 |

## Create TPU

@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
'
```

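To sanity-check the install before training, you can run a few lines of Python on each worker (a quick sketch; the printed version strings will differ for other nightlies):

```python
# Quick post-install check: the nightly wheels above should expose the TPU devices.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
print("xla device:", xm.xla_device())                      # e.g. xla:0
print("device count:", xr.global_runtime_device_count())   # total devices in the slice
```
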
@@ -88,17 +90,18 @@ are fixed.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
export XLA_DISABLE_FUNCTIONALIZATION=1
export XLA_DISABLE_FUNCTIONALIZATION=0
export PROFILE_DIR=/tmp/
export CACHE_DIR=/tmp/
export DATASET_NAME=lambdalabs/naruto-blip-captions
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
export TRAIN_STEPS=50
export OUTPUT_DIR=/tmp/trained-model/
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'

python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
```

Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.

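The overhead comes from eagerly fetching a lazy XLA tensor, which forces extra host-device synchronization. The script instead hands the print to `xm.add_step_closure`, which runs only after the step's graph has executed. A self-contained sketch of that pattern with a toy model (not the actual training loop):

```python
import torch
import torch_xla.core.xla_model as xm

def print_loss_closure(step, loss):
    # Runs after the step's results have materialized, so reading `loss` is cheap here.
    print(f"Step: {step}, Loss: {loss}")

device = xm.xla_device()
w = torch.zeros(4, device=device, requires_grad=True)
for step in range(3):
    loss = (w - 1.0).pow(2).sum()
    loss.backward()
    xm.add_step_closure(print_loss_closure, args=(step, loss))
    xm.mark_step()  # cut the lazy graph; queued closures fire once execution finishes
```
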
### Environment Variables Explained

* `XLA_DISABLE_FUNCTIONALIZATION`: Used to optimize the performance of the AdamW optimizer.

@@ -140,33 +140,43 @@ class TrainSD:
self.optimizer.step()

def start_training(self):
times = []
last_time = time.time()
step = 0
while True:
if self.global_step >= self.args.max_train_steps:
xm.mark_step()
break
if step == 4 and PROFILE_DIR is not None:
xm.wait_device_ops()
xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
dataloader_exception = False
measure_start_step = args.measure_start_step
assert measure_start_step < self.args.max_train_steps
total_time = 0
for step in range(0, self.args.max_train_steps):
try:
batch = next(self.dataloader)
except Exception as e:
dataloader_exception = True
print(e)
break
if step == measure_start_step and PROFILE_DIR is not None:
xm.wait_device_ops()
xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
last_time = time.time()
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
step_time = time.time() - last_time
if step >= 10:
times.append(step_time)
print(f"step: {step}, step_time: {step_time}")
if step % 5 == 0:
print(f"step: {step}, loss: {loss}")
last_time = time.time()
self.global_step += 1
step += 1
# print(f"Average step time: {sum(times)/len(times)}")
xm.wait_device_ops()

def print_loss_closure(step, loss):
print(f"Step: {step}, Loss: {loss}")

if args.print_loss:
xm.add_step_closure(
print_loss_closure,
args=(
self.global_step,
loss,
),
)
xm.mark_step()
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
else:
print("dataloader exception happen, skip result")
|
||||
return
|
||||
|
||||
def step_fn(
|
||||
self,
|
||||
@@ -180,7 +190,10 @@ class TrainSD:
noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
bsz = latents.shape[0]
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device=latents.device,
)
timesteps = timesteps.long()

@@ -224,9 +237,6 @@ class TrainSD:

def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
)
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
parser.add_argument(
"--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
@@ -283,15 +287,6 @@ def parse_args():
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
@@ -304,7 +299,6 @@ def parse_args():
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
@@ -374,12 +368,19 @@ def parse_args():
default=1,
help=("Number of subprocesses to use for data loading to cpu."),
)
parser.add_argument(
"--loader_prefetch_factor",
type=int,
default=2,
help=("Number of batches loaded in advance by each worker."),
)
parser.add_argument(
"--device_prefetch_size",
type=int,
default=1,
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
)
parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
choices=["no", "bf16"],
help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -409,6 +406,12 @@ def parse_args():
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--print_loss",
default=False,
action="store_true",
help=("Print loss at every step."),
)

args = parser.parse_args()

@@ -436,7 +439,6 @@ def load_dataset(args):
# Downloading and loading a dataset from the hub.
dataset = datasets.load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
@@ -481,9 +483,7 @@ def main(args):
_ = xp.start_server(PORT)

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh_shape = (num_devices, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)

text_encoder = CLIPTextModel.from_pretrained(
@@ -520,6 +520,7 @@ def main(args):
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
# as these weights are only used for inference, keeping weights in full
# precision is not required.
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
if args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

device = xm.xla_device()
print("device: ", device)
print("weight_dtype: ", weight_dtype)

# Move text_encode and vae to device and cast to weight_dtype
text_encoder = text_encoder.to(device, dtype=weight_dtype)
vae = vae.to(device, dtype=weight_dtype)
unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def main(args):
collate_fn=collate_fn,
num_workers=args.dataloader_num_workers,
batch_size=args.train_batch_size,
prefetch_factor=args.loader_prefetch_factor,
)

train_dataloader = pl.MpDeviceLoader(
train_dataloader,
device,
input_sharding={
"pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
"input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
"pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
"input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
},
loader_prefetch_size=args.loader_prefetch_size,
device_prefetch_size=args.device_prefetch_size,
)

num_hosts = xr.process_count()
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
print(f"Instantaneous batch size per device = {args.train_batch_size}")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
print(f" Total optimization steps = {args.max_train_steps}")


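For reading the batch-size logging above: `--train_batch_size` is per host, and with `minibatch=True` each host loads only its own slice of the global batch. A small worked example with made-up cluster numbers:

```python
# Hypothetical numbers, only to illustrate the bookkeeping printed above.
train_batch_size = 32                                   # per host (PER_HOST_BATCH_SIZE)
num_hosts = 16                                          # xr.process_count()
num_devices = 64                                        # xr.global_runtime_device_count()
num_devices_per_host = num_devices // num_hosts         # 4 chips per v4/v5p host
per_device_batch = train_batch_size // num_devices_per_host   # 8
global_batch = train_batch_size * num_hosts                    # 512
print(per_device_batch, global_batch)
```
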
@@ -20,8 +20,8 @@ import torch.nn.functional as F
from torch import nn

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils import deprecate, is_torch_xla_available, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


@@ -36,6 +36,15 @@ if is_xformers_available():
else:
xformers = None

if is_torch_xla_available():
# flash attention pallas kernel is introduced in the torch_xla 2.3 release.
if is_torch_xla_version(">", "2.2"):
from torch_xla.experimental.custom_kernel import flash_attention
from torch_xla.runtime import is_spmd
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False


@maybe_allow_in_graph
class Attention(nn.Module):
@@ -275,6 +284,33 @@ class Attention(nn.Module):
)
self.set_processor(processor)

def set_use_xla_flash_attention(
self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
) -> None:
r"""
Set whether to use xla flash attention from `torch_xla` or not.

Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
"""
if use_xla_flash_attention:
if not is_torch_xla_available():
raise ImportError("torch_xla is not available")
elif is_torch_xla_version("<", "2.3"):
raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
elif is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
else:
processor = XLAFlashAttnProcessor2_0(partition_spec)
else:
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)

def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
r"""
Set whether to use npu flash attention from `torch_npu` or not.
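For reference, a minimal usage sketch of the new hook (assuming a TPU host with torch_xla >= 2.4 and SPMD initialized, as in the training script above; the checkpoint is the one this example trains):

```python
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from diffusers import UNet2DConditionModel

xr.use_spmd()
xs.set_global_mesh(xs.get_1d_mesh("data"))

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-base", subfolder="unet"
)
# Swap every attention module over to XLAFlashAttnProcessor2_0; the partition spec
# shards the batch axis of Q/K/V over the "data" mesh axis when SPMD is active.
unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

# ... run training / inference ...
unet.disable_xla_flash_attention()  # restore the default attention processor
```
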
@@ -2845,6 +2881,122 @@ class AttnProcessor2_0:
return hidden_states


class XLAFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""

def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
if is_torch_xla_version("<", "2.3"):
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
if is_spmd() and is_torch_xla_version("<", "2.4"):
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
self.partition_spec = partition_spec

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
# Convert mask to float and replace 0s with -inf and 1s with 0
attention_mask = (
attention_mask.float()
.masked_fill(attention_mask == 0, float("-inf"))
.masked_fill(attention_mask == 1, float(0.0))
)

# Apply attention mask to key
key = key + attention_mask
query /= math.sqrt(query.shape[3])
partition_spec = self.partition_spec if is_spmd() else None
hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
logger.warning(
"Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
@@ -5226,6 +5378,7 @@ AttentionProcessor = Union[
FusedCogVideoXAttnProcessor2_0,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
XLAFlashAttnProcessor2_0,
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,

@@ -208,6 +208,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_npu_flash_attention(False)

def set_use_xla_flash_attention(
self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_xla_flash_attention method
# gets the message
def fn_recursive_set_flash_attention(module: torch.nn.Module):
if hasattr(module, "set_use_xla_flash_attention"):
module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)

for child in module.children():
fn_recursive_set_flash_attention(child)

for module in self.children():
if isinstance(module, torch.nn.Module):
fn_recursive_set_flash_attention(module)

def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
r"""
Enable the flash attention pallas kernel for torch_xla.
"""
self.set_use_xla_flash_attention(True, partition_spec)

def disable_xla_flash_attention(self):
r"""
Disable the flash attention pallas kernel for torch_xla.
"""
self.set_use_xla_flash_attention(False)

def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:

@@ -86,6 +86,7 @@ from .import_utils import (
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
is_torch_xla_version,
is_torchsde_available,
is_torchvision_available,
is_transformers_available,

@@ -700,6 +700,21 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)


def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.

Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A string version of torch_xla
"""
if not is_torch_xla_available():
return False
return compare_versions(parse(_torch_xla_version), operation, version)


def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.

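A short sketch of how these helpers are intended to be combined to gate XLA-only code paths, mirroring the guard added in the attention processor above (assumed usage, not part of this change):

```python
from diffusers.utils import is_torch_xla_available, is_torch_xla_version

if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
    # The pallas flash attention kernel ships with torch_xla 2.3 and newer.
    from torch_xla.experimental.custom_kernel import flash_attention
else:
    flash_attention = None
```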