From cd0a4a82cf8625b96e2889afee2fce5811b35c05 Mon Sep 17 00:00:00 2001
From: Leo Jiang <74156916+leisuzz@users.noreply.github.com>
Date: Thu, 6 Feb 2025 06:59:58 -0700
Subject: [PATCH] [bugfix] NPU Adaption for Sana (#10724)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* NPU Adaption for Sanna

* [bugfix]NPU Adaption for Sanna

---------

Co-authored-by: J石页
Co-authored-by: Sayak Paul
---
 examples/dreambooth/train_dreambooth_lora_sana.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 9e69bd6a66..798980e86b 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -995,7 +995,8 @@ def main(args):
     if args.enable_npu_flash_attention:
         if is_torch_npu_available():
             logger.info("npu flash attention enabled.")
-            transformer.enable_npu_flash_attention()
+            for block in transformer.transformer_blocks:
+                block.attn2.set_use_npu_flash_attention(True)
         else:
             raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
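
Note on the change: the patch replaces the model-level `transformer.enable_npu_flash_attention()` call with a loop that enables NPU flash attention only on each Sana block's cross-attention (`attn2`). Below is a minimal standalone sketch of that pattern, not part of the patch itself; it assumes a recent diffusers with `SanaTransformer2DModel`, `Attention.set_use_npu_flash_attention`, and `is_torch_npu_available`, and the helper name `maybe_enable_npu_flash_attention` is hypothetical.

```python
# Minimal sketch of the per-block NPU flash attention toggle used by the patch.
# Assumes diffusers with Sana support and torch_npu installed on an NPU device.
from diffusers import SanaTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available


def maybe_enable_npu_flash_attention(transformer: SanaTransformer2DModel, enable: bool) -> None:
    """Enable NPU flash attention on each block's cross-attention, mirroring the patched script."""
    if not enable:
        return
    if not is_torch_npu_available():
        raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device")
    for block in transformer.transformer_blocks:
        # Attention.set_use_npu_flash_attention(True) swaps the block's cross-attention
        # processor for the NPU flash attention processor.
        block.attn2.set_use_npu_flash_attention(True)
```

The per-block loop restricts the NPU processor to the cross-attention layers, rather than switching every attention module in the model as the previous model-level call did.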