
Implement CustomDiffusionAttnProcessor2_0. (#4604)

* Implement `CustomDiffusionAttnProcessor2_0`

* Doc-strings and type annotations for `CustomDiffusionAttnProcessor2_0`. (#1)

* Update attnprocessor.md

* Update attention_processor.py

* Interop support for `CustomDiffusionAttnProcessor2_0`.

* Formatted `attention_processor.py`.

* Formatted doc-string in `attention_processor.py`.

* Conditionally select `CustomDiffusionAttnProcessor2_0` in the training example.

* Remove unnecessary reference impl in comments.

* Fix `save_attn_procs`.
Authored by Ruoxi on 2023-09-18 05:49:00 -07:00; committed by GitHub.
parent 7b39f43c06
commit 16b9a57d29
4 changed files with 139 additions and 7 deletions
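
Background for the diff below: the training example builds one Custom Diffusion processor per attention layer and attaches the mapping with `unet.set_attn_processor`. The following is a condensed, hypothetical sketch of that pattern, assuming the constructor arguments the script passes (`train_kv`, `train_q_out`, `hidden_size`, `cross_attention_dim`); it is not a verbatim excerpt from this commit, and the checkpoint name is illustrative:

```python
import torch.nn.functional as F
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Prefer the fused-SDPA processor on PyTorch 2.0+, as the diff below does.
attention_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else CustomDiffusionAttnProcessor
)

attn_procs = {}
for name in unet.attn_processors.keys():
    # "attn1" blocks are self-attention and have no cross-attention dimension.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    # Derive the hidden size of the block this processor lives in.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    attn_procs[name] = attention_class(
        train_kv=cross_attention_dim is not None,  # illustrative: train only cross-attn K/V
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )
unet.set_attn_processor(attn_procs)
```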


@@ -51,7 +51,11 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
+from diffusers.models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
+    CustomDiffusionXFormersAttnProcessor,
+)
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -870,7 +874,9 @@ def main(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-    attention_class = CustomDiffusionAttnProcessor
+    attention_class = (
+        CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
+    )
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
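
The `hasattr(F, "scaled_dot_product_attention")` gate above is diffusers' standard runtime check for PyTorch 2.0: the fused attention kernel only exists from torch 2.0 onward. A minimal sketch of what the 2.0 processor buys over the eager path (tensor shapes are illustrative assumptions, not taken from the commit):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 77, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

if hasattr(F, "scaled_dot_product_attention"):
    # PyTorch 2.0+: one fused call computes softmax(Q @ K^T / sqrt(d)) @ V,
    # dispatching to FlashAttention/memory-efficient kernels when possible.
    out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
else:
    # Pre-2.0 fallback: the same math step by step in eager mode, as the
    # original CustomDiffusionAttnProcessor computes it.
    scores = query @ key.transpose(-2, -1) / head_dim**0.5
    out = scores.softmax(dim=-1) @ value
```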