1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/cfgzero/__init__.py
Vladimir Mandic e8b5ea3847 major refactor: remove backend original
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-05 13:16:46 -04:00

64 lines
2.8 KiB
Python

# reference: <https://github.com/WeichenFan/CFG-Zero-star>
from modules import shared, processing, sd_models
orig_pipeline = None
supported = [
'FluxPipeline',
'CogView4Pipeline',
'StableDiffusion3Pipeline',
'HiDreamImagePipeline',
'WanPipeline',
'HunyuanVideoPipeline',
]
def apply(p: processing.StableDiffusionProcessing):
if not shared.opts.cfgzero_enabled:
return None
cls = shared.sd_model.__class__.__name__ if shared.sd_loaded else 'None'
if 'CFGZero' in cls:
unapply()
if cls not in supported:
return None
global orig_pipeline # pylint: disable=global-statement
orig_pipeline = shared.sd_model
if cls == 'FluxPipeline':
from diffusers import pipelines
from modules.cfgzero.flux_pipeline import FluxCFGZeroPipeline
shared.sd_model = sd_models.switch_pipe(FluxCFGZeroPipeline, shared.sd_model)
pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["fluxcfgzero"] = FluxCFGZeroPipeline
pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["fluxcfgzero"] = pipelines.FluxImg2ImgPipeline
pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["fluxcfgzero"] = pipelines.FluxInpaintPipeline
if cls == 'CogView4Pipeline':
from modules.cfgzero.cogview4_pipeline import CogView4CFGZeroPipeline
shared.sd_model = sd_models.switch_pipe(CogView4CFGZeroPipeline, shared.sd_model)
if cls == 'StableDiffusion3Pipeline':
from modules.cfgzero.sd3_pipeline import StableDiffusion3CFGZeroPipeline
shared.sd_model = sd_models.switch_pipe(StableDiffusion3CFGZeroPipeline, shared.sd_model)
if cls == 'HiDreamImagePipeline':
from modules.cfgzero.hidream_pipeline import HiDreamImageCFGZeroPipeline
shared.sd_model = sd_models.switch_pipe(HiDreamImageCFGZeroPipeline, shared.sd_model)
if cls == 'WanPipeline':
from modules.cfgzero.wan_t2v_pipeline import WanCFGZeroPipeline
shared.sd_model = sd_models.switch_pipe(WanCFGZeroPipeline, shared.sd_model)
if cls == 'HunyuanVideoPipeline':
from modules.cfgzero.hunyuan_t2v_pipeline import HunyuanVideoCFGZeroPipeline
shared.sd_model = sd_models.switch_pipe(HunyuanVideoCFGZeroPipeline, shared.sd_model)
shared.log.debug(f'Apply CFGZero: cls={cls} init={shared.opts.cfgzero_enabled} star={shared.opts.cfgzero_star} steps={shared.opts.cfgzero_steps}')
p.task_args['use_zero_init'] = shared.opts.cfgzero_enabled
p.task_args['use_cfg_zero_star'] = shared.opts.cfgzero_star
p.task_args['zero_steps'] = int(shared.opts.cfgzero_steps)
p.extra_generation_params['CFGZero'] = True
def unapply():
global orig_pipeline # pylint: disable=global-statement
if orig_pipeline is not None:
shared.sd_model = orig_pipeline
orig_pipeline = None
return shared.sd_model.__class__