import gradio as gr from modules import scripts_manager, processing, shared, sd_models registered = False class Script(scripts_manager.Script): def __init__(self): super().__init__() self.orig_pipe = None self.register() def title(self): return 'APG: Adaptive Projected Guidance' def show(self, is_img2img): return not is_img2img def ui(self, _is_img2img): # ui elements with gr.Row(): gr.HTML('  APG: Adaptive Projected Guidance
') with gr.Row(): eta = gr.Slider(label="ETA", value=1.0, minimum=0, maximum=2.0, step=0.05) momentum = gr.Slider(label="Momentum", value=-0.50, minimum=-1.0, maximum=1.0, step=0.05) threshold = gr.Slider(label="Threshold", value=0.0, minimum=0.0, maximum=10.0, step=0.05) return [eta, momentum, threshold] def register(self): # register xyz grid elements global registered # pylint: disable=global-statement if registered: return registered = True def apply_field(field): def fun(p, x, xs): # pylint: disable=unused-argument setattr(p, field, x) self.run(p) return fun import sys xyz_classes = [v for k, v in sys.modules.items() if 'xyz_grid_classes' in k] if xyz_classes and len(xyz_classes) > 0: xyz_classes = xyz_classes[0] options = [ xyz_classes.AxisOption("[APG] ETA", float, apply_field("apg_eta")), xyz_classes.AxisOption("[APG] Momentum", float, apply_field("apg_momentum")), xyz_classes.AxisOption("[APG] Threshold", float, apply_field("apg_threshold")), ] for option in options: if option not in xyz_classes.axis_options: xyz_classes.axis_options.append(option) def run(self, p: processing.StableDiffusionProcessing, eta = 0.0, momentum = 0.0, threshold = 0.0): # pylint: disable=arguments-differ supported_model_list = ['sd', 'sdxl', 'sc'] if shared.sd_model_type not in supported_model_list: shared.log.warning(f'APG: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') return None from modules import apg apg.eta = getattr(p, 'apg_eta', eta) # use values set by xyz grid or via ui apg.momentum = getattr(p, 'apg_momentum', momentum) apg.threshold = getattr(p, 'apg_threshold', threshold) apg.buffer = apg.MomentumBuffer(apg.momentum) # recreate buffer # pipelines with call to apg.normalized_guidance instead of default if shared.sd_model_type == "sd": self.orig_pipe = shared.sd_model shared.sd_model = sd_models.switch_pipe(apg.StableDiffusionPipelineAPG, shared.sd_model) if shared.sd_model_type == "sdxl": self.orig_pipe = shared.sd_model shared.sd_model = sd_models.switch_pipe(apg.StableDiffusionXLPipelineAPG, shared.sd_model) elif shared.sd_model_type == "sc": self.orig_pipe = shared.sd_model.prior_pipe shared.sd_model.prior_pipe = sd_models.switch_pipe(apg.StableCascadePriorPipelineAPG, shared.sd_model.prior_pipe) shared.log.info(f'APG apply: guidance={p.cfg_scale} momentum={apg.momentum} eta={apg.eta} threshold={apg.threshold} class={shared.sd_model.__class__.__name__}') p.extra_generation_params["APG"] = f'ETA={apg.eta} Momentum={apg.momentum} Threshold={apg.threshold}' # processed = processing.process_images(p) return None def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, eta, momentum, threshold): # pylint: disable=arguments-differ, unused-argument from modules import apg if self.orig_pipe is None: return processed # restore pipeline if shared.sd_model_type == "sdxl" or shared.sd_model_type == "sd": shared.sd_model = self.orig_pipe elif shared.sd_model_type == "sc": shared.sd_model.prior_pipe = self.orig_pipe apg.buffer = None self.orig_pipe = None return processed