mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
56 lines
2.6 KiB
Python
56 lines
2.6 KiB
Python
from diffusers.pipelines import StableDiffusionPipeline, StableDiffusionXLPipeline # pylint: disable=unused-import
|
|
from modules import shared, processing, sd_models
|
|
from modules.pag.pipe_sd import StableDiffusionPAGPipeline
|
|
from modules.pag.pipe_sdxl import StableDiffusionXLPAGPipeline
|
|
from modules.control.units import detect
|
|
|
|
|
|
orig_pipeline = None
|
|
|
|
|
|
def apply(p: processing.StableDiffusionProcessing): # pylint: disable=arguments-differ
|
|
global orig_pipeline # pylint: disable=global-statement
|
|
cls = shared.sd_model.__class__ if shared.sd_loaded else None
|
|
if cls == StableDiffusionPAGPipeline or cls == StableDiffusionXLPAGPipeline:
|
|
cls = unapply()
|
|
if p.pag_scale == 0:
|
|
return
|
|
if cls is not None and 'PAG' in cls.__name__:
|
|
pass
|
|
elif detect.is_sd15(cls):
|
|
if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:
|
|
shared.log.warning(f'PAG: pipeline={cls.__name__} not implemented')
|
|
return None
|
|
orig_pipeline = shared.sd_model
|
|
shared.sd_model = sd_models.switch_pipe(StableDiffusionPAGPipeline, shared.sd_model)
|
|
elif detect.is_sdxl(cls):
|
|
if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:
|
|
shared.log.warning(f'PAG: pipeline={cls.__name__} not implemented')
|
|
return None
|
|
orig_pipeline = shared.sd_model
|
|
shared.sd_model = sd_models.switch_pipe(StableDiffusionXLPAGPipeline, shared.sd_model)
|
|
elif detect.is_f1(cls):
|
|
p.task_args['true_cfg_scale'] = p.pag_scale
|
|
else:
|
|
# shared.log.warning(f'PAG: pipeline={cls.__name__} required={StableDiffusionPipeline.__name__}')
|
|
return None
|
|
|
|
p.task_args['pag_scale'] = p.pag_scale
|
|
p.task_args['pag_adaptive_scaling'] = p.pag_adaptive
|
|
p.task_args['pag_adaptive_scale'] = p.pag_adaptive
|
|
pag_applied_layers = shared.opts.pag_apply_layers
|
|
pag_applied_layers_index = pag_applied_layers.split() if len(pag_applied_layers) > 0 else []
|
|
pag_applied_layers_index = [p.strip() for p in pag_applied_layers_index]
|
|
p.task_args['pag_applied_layers_index'] = pag_applied_layers_index if len(pag_applied_layers_index) > 0 else ['m0'] # Available layers: d[0-5], m[0], u[0-8]
|
|
p.extra_generation_params["CFG true"] = p.pag_scale
|
|
p.extra_generation_params["CFG adaptive"] = p.pag_adaptive
|
|
# shared.log.debug(f'{c}: args={p.task_args}')
|
|
|
|
|
|
def unapply():
|
|
global orig_pipeline # pylint: disable=global-statement
|
|
if orig_pipeline is not None:
|
|
shared.sd_model = orig_pipeline
|
|
orig_pipeline = None
|
|
return shared.sd_model.__class__
|