import gradio as gr
import torch
import diffusers
from huggingface_hub import hf_hub_download
from modules import scripts_manager, processing, shared, sd_models, devices, ipadapter
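
# InstantIR image restoration for the img2img tab: temporarily switches the loaded SDXL checkpoint to InstantIRPipeline,
# wires its adapter, aggregator and previewer weights into processing, and can restore the original pipeline afterwards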
class Script(scripts_manager.Script):
    def __init__(self):
        super().__init__()
        self.orig_pipe = None
        self.orig_ip_unapply = None

    def title(self):
        return 'InstantIR: Image Restoration'

    def show(self, is_img2img):
        return is_img2img
    def ui(self, _is_img2img): # ui elements
        with gr.Row():
            gr.HTML('InstantIR: Image Restoration<br>')
        with gr.Row():
            start = gr.Slider(label='Preview start', minimum=0.0, maximum=1.0, step=0.01, value=0.0)
            end = gr.Slider(label='Preview end', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
        with gr.Row():
            hq = gr.Checkbox(label='HQ init latents', value=False)
            unload = gr.Checkbox(label='Unload after processing', value=False, visible=False)
        with gr.Row():
            multistep = gr.Checkbox(label='Multistep restore', value=False)
            adastep = gr.Checkbox(label='Adaptive restore', value=False)
        with gr.Row():
            image = gr.Image(label='Override guidance image')
        return [start, end, hq, multistep, adastep, image, unload]
    def run(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=arguments-differ
        supported_model_list = ['sdxl']
        if not hasattr(p, 'init_images') or len(p.init_images) == 0:
            shared.log.warning('InstantIR: no image')
            return None
        if shared.sd_model_type not in supported_model_list and shared.sd_model.__class__.__name__ != "InstantIRPipeline":
            shared.log.warning(f'InstantIR: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
            return None
        start, end, hq, multistep, adastep, image, _unload = args
        from scripts import instantir
        if shared.sd_model_type == "sdxl":
            if shared.sd_model.__class__.__name__ != "InstantIRPipeline":
                self.orig_pipe = shared.sd_model
                self.orig_ip_unapply = ipadapter.unapply
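                # download InstantIR adapter, aggregator and previewer weights from the InstantX/InstantIR hub repo into the local HF cache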
                adapter_file = hf_hub_download('InstantX/InstantIR', subfolder='models', filename='adapter.pt', cache_dir=shared.opts.hfcache_dir)
                aggregator_file = hf_hub_download('InstantX/InstantIR', subfolder='models', filename='aggregator.pt', cache_dir=shared.opts.hfcache_dir)
                previewer_file = hf_hub_download('InstantX/InstantIR', subfolder='models', filename='previewer_lora_weights.bin', cache_dir=shared.opts.hfcache_dir)
                shared.log.debug(f'InstantIR: adapter="{adapter_file}" aggregator="{aggregator_file}" previewer="{previewer_file}"')
                shared.sd_model = sd_models.switch_pipe(instantir.InstantIRPipeline, shared.sd_model)
                instantir.load_adapter_to_pipe(
                    pipe=shared.sd_model,
                    pretrained_model_path_or_dict=adapter_file,
                    image_encoder_or_path='facebook/dinov2-large',
                    use_lcm=False,
                    use_adaln=True,
                )
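                # attach the previewer LoRA, switch to the SDXL base DDPM scheduler used by InstantIR, and load the aggregator weights onto the active device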
                shared.sd_model.prepare_previewers(previewer_file)
                shared.sd_model.scheduler = diffusers.DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
                pretrained_state_dict = torch.load(aggregator_file)
                shared.sd_model.aggregator.load_state_dict(pretrained_state_dict)
                shared.sd_model.aggregator.to(device=devices.device, dtype=devices.dtype)
                ipadapter.unapply = self.dummy_unapply # disable unapply since main processing would otherwise unload the ipadapter as unneeded
                sd_models.clear_caches()
        sd_models.apply_balanced_offload(shared.sd_model)
        shared.log.info(f'InstantIR: class={shared.sd_model.__class__.__name__} start={start} end={end} multistep={multistep} adastep={adastep} hq={hq} cache={shared.opts.hfcache_dir}')
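        # hand generation over to the InstantIR pipeline: it runs its own sampler and receives the restoration controls as task args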
        p.sampler_name = 'Default' # InstantIR uses its own sampler
        p.init() # run init early to take care of resizing
        p.task_args['previewer_scheduler'] = instantir.LCMSingleStepScheduler.from_config(shared.sd_model.scheduler.config)
        p.task_args['image'] = p.init_images
        p.task_args['save_preview_row'] = False
        p.task_args['init_latents_with_lq'] = not hq
        p.task_args['multistep_restore'] = multistep
        p.task_args['adastep_restore'] = adastep
        p.task_args['preview_start'] = start
        p.task_args['preview_end'] = end
        p.task_args['ip_adapter_image'] = image
        p.extra_generation_params["InstantIR"] = f'Start={start} End={end} HQ={hq} Multistep={multistep} Adastep={adastep}'
        devices.torch_gc()
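    # no-op replacement for ipadapter.unapply so the InstantIR adapter survives main processing cleanup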
    def dummy_unapply(self, pipe, unload): # pylint: disable=unused-argument
        pass
    def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=arguments-differ, unused-argument
        _start, _end, _hq, _multistep, _adastep, _image, unload = args
        if unload:
            shared.log.info('InstantIR: unloading adapter')
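            # put back the original ipadapter.unapply hook and use it to detach the adapter from the pipeline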
            if self.orig_ip_unapply is not None:
                ipadapter.unapply = self.orig_ip_unapply
                self.orig_ip_unapply = None
                ipadapter.unapply(shared.sd_model)
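            # drop the aggregator and restore the original SDXL pipeline and its unet config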
            if hasattr(shared.sd_model, 'aggregator'):
                shared.sd_model.aggregator = None
            if self.orig_pipe is not None:
                shared.sd_model = self.orig_pipe
                self.orig_pipe = None
                shared.sd_model.unet.register_to_config(encoder_hid_dim_type=None)
                sd_models.apply_balanced_offload(shared.sd_model)
                shared.log.debug(f'InstantIR restore: class={shared.sd_model.__class__.__name__}')
        devices.torch_gc()
        return processed