mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Implement CustomDiffusionAttnProcessor2_0. (#4604)
* Implement `CustomDiffusionAttnProcessor2_0` * Doc-strings and type annotations for `CustomDiffusionAttnProcessor2_0`. (#1) * Update attnprocessor.md * Update attention_processor.py * Interops for `CustomDiffusionAttnProcessor2_0`. * Formatted `attention_processor.py`. * Formatted doc-string in `attention_processor.py` * Conditional CustomDiffusion2_0 for training example. * Remove unnecessary reference impl in comments. * Fix `save_attn_procs`.
This commit is contained in:
@@ -51,7 +51,11 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
@@ -870,7 +874,9 @@ def main(args):
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
attention_class = CustomDiffusionAttnProcessor
|
||||
attention_class = (
|
||||
CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
|
||||
)
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
|
||||
Reference in New Issue
Block a user