mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
from modules import shared, sd_models, devices, attention
|
|
from .linfusion import LinFusion
|
|
from .attention import GeneralizedLinearAttention
|
|
|
|
|
|
# LinFusion instance currently mounted by apply(); None when nothing is applied (reset by unapply()).
# String annotation: the value is legitimately None, and this avoids evaluating the class at import time.
applied: "LinFusion | None" = None
|
|
|
|
|
|
def detect(pipeline):
    """Look up the pretrained LinFusion model id for a pipeline.

    Matching is done on the exact class name of `pipeline`, so subclasses or
    other pipeline classes are not recognized.

    Returns:
        The model id string for a supported pipeline class, otherwise None.
    """
    known_models = {
        'StableDiffusionXLPipeline': "Yuanshi/LinFusion-XL",
        'StableDiffusionPipeline': "Yuanshi/LinFusion-1-5",
    }
    pipeline_name = pipeline.__class__.__name__
    return known_models.get(pipeline_name)
|
|
|
|
|
|
def apply(pipeline, pretrained: bool = True):
    """Mount a LinFusion module onto `pipeline.unet`.

    Does nothing when `shared.opts.enable_linfusion` is off or when a LinFusion
    instance is already recorded in the module-level `applied`.

    Args:
        pipeline: pipeline object; it must expose a `unet` attribute.
        pretrained: when True, load weights for the model id returned by
            `detect(pipeline)` (cached under `shared.opts.hfcache_dir`); when
            False, build a LinFusion from `LinFusion.get_default_config` for the UNet.

    Unsupported pipelines (no `unet`, or no known pretrained model) are logged
    as errors and left untouched.
    """
    global applied # pylint: disable=global-statement
    if not shared.opts.enable_linfusion:
        return
    if applied is not None:
        return
    unet = getattr(pipeline, 'unet', None)
    if unet is None:
        # e.g. pipelines without a UNet; previously an AttributeError when pretrained=False
        shared.log.error('LinFusion: unsupported model type')
        return
    if pretrained:
        model_path = detect(pipeline)
        if model_path is None:
            shared.log.error('LinFusion: unsupported model type')
            return
        linfusion = LinFusion.from_pretrained(model_path, cache_dir=shared.opts.hfcache_dir)
    else:
        model_path = None
        default_config = LinFusion.get_default_config(unet=unet)
        linfusion = LinFusion(**default_config)
    # match the UNet's device/dtype before mounting
    linfusion = linfusion.to(device=unet.device, dtype=unet.dtype)
    linfusion.mount_to(unet=unet)
    # commit global state only after a successful mount, so a failure does not block later retries
    applied = linfusion
    shared.log.info(f'Applying LinFusion: class={applied.__class__.__name__} model="{model_path}" modules={len(applied.modules_dict)}')
|
|
|
|
|
|
def unapply(pipeline):
    """Undo `apply`: reset the pipeline's attention and release the LinFusion module.

    No-op when nothing is recorded in the module-level `applied`.

    Args:
        pipeline: the pipeline LinFusion was applied to; passed to
            `attention.set_diffusers_attention` to reset its attention.
    """
    global applied # pylint: disable=global-statement
    if applied is None:
        return
    attention.set_diffusers_attention(pipeline)
    # drop the last module-level reference *before* collecting, otherwise
    # torch_gc runs while the LinFusion module (and its tensors) are still reachable
    applied = None
    devices.torch_gc()
|