mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
67 lines
2.7 KiB
Python
67 lines
2.7 KiB
Python
from typing import Union
|
|
import time
|
|
import diffusers.utils
|
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
|
from modules.shared import log, opts
|
|
from modules.control.units import detect
|
|
from modules import sd_models
|
|
|
|
|
|
# Display label for this control unit; interpolated into this module's log messages.
what: str = 'Reference'
|
|
|
|
|
|
def list_models():
    """Return the model names offered by this control unit.

    Reference mode exposes a single pseudo-model, so the returned list always
    holds exactly one entry. A fresh list is built on every call.
    """
    available = ['Reference']
    return available
|
|
|
|
|
|
class ReferencePipeline():
    """Wrap a loaded SD 1.5 / SDXL pipeline in a diffusers community "reference" pipeline.

    The community pipeline class is resolved with
    ``diffusers.utils.get_class_from_dynamic_module`` and instantiated from the
    component objects (VAE, text encoder(s), tokenizer(s), UNet, scheduler,
    feature extractor) of the original pipeline.

    Attributes:
        orig_pipeline: the pipeline passed to ``__init__``; handed back by ``restore()``.
        pipeline: the constructed reference pipeline, or ``None`` when the input
            pipeline is missing or its model type is unsupported.
    """

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the reference pipeline around ``pipeline``.

        Args:
            pipeline: loaded pipeline whose components are reused; SDXL and
                SD 1.5 models (as classified by ``detect``) are supported.
            dtype: optional target passed to ``.to(dtype)`` on the new pipeline.

        A missing pipeline or an unsupported model type is logged as an error and
        leaves ``self.pipeline`` set to ``None``; no exception is raised for those cases.
        """
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline = None
        if pipeline is None:
            log.error(f'Control {what} model pipeline: model not loaded')
            return
        sdxl = detect.is_sdxl(pipeline)
        if sdxl:
            module_name = 'stable_diffusion_xl_reference'
        elif detect.is_sd15(pipeline):
            module_name = 'stable_diffusion_reference'
        else:
            # Bail out before touching the caller's pipeline: nothing will be built.
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return
        # Unfuse only once we know a reference pipeline will actually share these
        # components. NOTE(review): presumably the reference pipeline's attention
        # handling needs separate q/k/v projections — confirm.
        if opts.diffusers_fuse_projections and hasattr(pipeline, 'unfuse_qkv_projections'):
            pipeline.unfuse_qkv_projections()
        cls = diffusers.utils.get_class_from_dynamic_module(module_name, module_file='pipeline.py')
        self.pipeline = cls(**self._components(pipeline, sdxl))
        sd_models.move_model(self.pipeline, pipeline.device)
        if dtype is not None:
            self.pipeline = self.pipeline.to(dtype)
        t1 = time.time()
        # Defensive: only reachable if ``.to(dtype)`` did not return a pipeline.
        if self.pipeline is not None:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
        else:
            log.error(f'Control {what} pipeline: not initialized')

    @staticmethod
    def _components(pipeline, sdxl: bool) -> dict:
        """Return constructor kwargs reusing ``pipeline``'s components for the chosen variant."""
        kwargs = dict(
            vae=pipeline.vae,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer,
            unet=pipeline.unet,
            scheduler=pipeline.scheduler,
            feature_extractor=getattr(pipeline, 'feature_extractor', None),
        )
        if sdxl:
            # SDXL carries a second text encoder / tokenizer pair.
            kwargs.update(
                text_encoder_2=pipeline.text_encoder_2,
                tokenizer_2=pipeline.tokenizer_2,
            )
        else:
            # SD 1.5 variant is built without a safety checker.
            kwargs.update(
                requires_safety_checker=False,
                safety_checker=None,
            )
        return kwargs

    def restore(self):
        """Drop the reference pipeline and return the original pipeline.

        NOTE(review): QKV projections unfused in ``__init__`` are not re-fused
        here; confirm whether a caller re-applies fusion when it is enabled.
        """
        self.pipeline = None
        return self.orig_pipeline
|