1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00

zimage-controlnet-union initial code

Signed-off-by: vladmandic <mandic00@live.com>
This commit is contained in:
vladmandic
2025-12-24 10:42:42 +01:00
parent c0141de02d
commit fab224c4df
4 changed files with 36 additions and 3 deletions

View File

@@ -16,6 +16,8 @@
- **Z-Image** support for loading transformer fine-tunes in safetensors format
as with any transformer/unet fine-tunes, place them in `models/unet`
and use **UNET Model** to load the safetensors file, as they are not complete models
- **Z-Image** support for **ControlNet Union**
includes 1.0, 2.0 and 2.1 variants
- **Detailer** support for segmentation models
some detection models can produce an exact segmentation mask rather than just a bounding box
to enable, set the `use segmentation` option

View File

@@ -111,6 +111,11 @@ predefined_hunyuandit = {
"HunyuanDiT Pose": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Pose',
"HunyuanDiT Depth": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Depth',
}
# Z-Image-Turbo ControlNet Union checkpoints (Hugging Face repo ids, hlky),
# keyed by the label shown in the UI — one entry per released variant.
predefined_zimage = {
    "Z-Image-Turbo Union 1.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union',
    "Z-Image-Turbo Union 2.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0',
    "Z-Image-Turbo Union 2.1": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1',
}
variants = {
'NoobAI Canny XL': 'fp16',
@@ -137,6 +142,7 @@ all_models.update(predefined_f1)
all_models.update(predefined_sd3)
all_models.update(predefined_qwen)
all_models.update(predefined_hunyuandit)
all_models.update(predefined_zimage)
cache_dir = 'models/control/controlnet'
load_lock = threading.Lock()
@@ -175,6 +181,8 @@ def api_list_models(model_type: str = None):
model_list += list(predefined_qwen)
if model_type == 'hunyuandit' or model_type == 'all':
model_list += list(predefined_hunyuandit)
if model_type == 'z_image':
model_list == list(predefined_zimage)
model_list += sorted(find_models())
return model_list
@@ -199,9 +207,11 @@ def list_models(refresh=False):
models = ['None'] + list(predefined_qwen) + sorted(find_models())
elif modules.shared.sd_model_type == 'hunyuandit':
models = ['None'] + list(predefined_hunyuandit) + sorted(find_models())
elif modules.shared.sd_model_type == 'z_image':
models = ['None'] + list(predefined_zimage) + sorted(find_models())
else:
log.warning(f'Control {what} model list failed: unknown model type')
models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(predefined_f1) + sorted(predefined_sd3) + sorted(find_models())
models = ['None'] + list(all_models) + sorted(find_models())
debug_log(f'Control list {what}: path={cache_dir} models={models}')
return models
@@ -263,6 +273,14 @@ class ControlNet():
elif shared.sd_model_type == 'hunyuandit':
from diffusers import HunyuanDiT2DControlNetModel as cls
config = 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Canny'
elif shared.sd_model_type == 'z_image':
from diffusers import ZImageControlNetModel as cls
if '2.0' in model_id:
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0'
elif '2.1' in model_id:
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1'
else:
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union'
else:
log.error(f'Control {what}: type={shared.sd_model_type} unsupported model')
return None, None
@@ -508,6 +526,16 @@ class ControlNetPipeline():
feature_extractor=None,
controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list
)
elif detect.is_zimage(pipeline) and len(controlnets) > 0:
from diffusers import ZImageControlNetPipeline
self.pipeline = ZImageControlNetPipeline(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
transformer=pipeline.transformer,
scheduler=pipeline.scheduler,
controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list
)
elif len(loras) > 0:
self.pipeline = pipeline
for lora in loras:

View File

@@ -28,3 +28,6 @@ def is_qwen(model):
def is_hunyuandit(model):
    """Return True when *model* is HunyuanDiT-based; delegates to is_compatible with pattern 'HunyuanDiT'."""
    return is_compatible(model, pattern='HunyuanDiT')
def is_zimage(model):
    """Return True when *model* is Z-Image-based; delegates to is_compatible with pattern 'ZImage'."""
    return is_compatible(model, pattern='ZImage')

View File

@@ -156,8 +156,8 @@ def decode(latents):
image = vae.decode(tensor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0, 1).detach()
t1 = time.time()
if (t1 - t0) > 1.0 and not first_run:
shared.log.warning(f'Decode: type="taesd" variant="{variant}" time{t1 - t0:.2f}')
if (t1 - t0) > 3.0 and not first_run:
shared.log.warning(f'Decode: type="taesd" variant="{variant}" long decode time={t1 - t0:.2f}')
first_run = False
return image
except Exception as e: