diff --git a/CHANGELOG.md b/CHANGELOG.md index 31c6fead0..5fc27ab65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ - **Z-Image** support loading transformer file-tunes in safetensors format as with any transformers/unet finetunes, place them then `models/unet` and use **UNET Model** to load safetensors file as they are not complete models + - **Z-Image** support for **ControlNet Union** + includes 1.0, 2.0 and 2.1 variants - **Detailer** support for segmentation models some detection models can produce exact segmentation mask and not just box to enable, set `use segmentation` option diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py index 571491987..e5db65221 100644 --- a/modules/control/units/controlnet.py +++ b/modules/control/units/controlnet.py @@ -111,6 +111,11 @@ predefined_hunyuandit = { "HunyuanDiT Pose": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Pose', "HunyuanDiT Depth": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Depth', } +predefined_zimage = { + "Z-Image-Turbo Union 1.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union', + "Z-Image-Turbo Union 2.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0', + "Z-Image-Turbo Union 2.1": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1', +} variants = { 'NoobAI Canny XL': 'fp16', @@ -137,6 +142,7 @@ all_models.update(predefined_f1) all_models.update(predefined_sd3) all_models.update(predefined_qwen) all_models.update(predefined_hunyuandit) +all_models.update(predefined_zimage) cache_dir = 'models/control/controlnet' load_lock = threading.Lock() @@ -175,6 +181,8 @@ def api_list_models(model_type: str = None): model_list += list(predefined_qwen) if model_type == 'hunyuandit' or model_type == 'all': model_list += list(predefined_hunyuandit) + if model_type == 'z_image': + model_list == list(predefined_zimage) model_list += sorted(find_models()) return model_list @@ -199,9 +207,11 @@ def list_models(refresh=False): models = ['None'] + list(predefined_qwen) + sorted(find_models()) elif modules.shared.sd_model_type == 'hunyuandit': models = ['None'] + list(predefined_hunyuandit) + sorted(find_models()) + elif modules.shared.sd_model_type == 'z_image': + models = ['None'] + list(predefined_zimage) + sorted(find_models()) else: log.warning(f'Control {what} model list failed: unknown model type') - models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(predefined_f1) + sorted(predefined_sd3) + sorted(find_models()) + models = ['None'] + list(all_models) + sorted(find_models()) debug_log(f'Control list {what}: path={cache_dir} models={models}') return models @@ -263,6 +273,14 @@ class ControlNet(): elif shared.sd_model_type == 'hunyuandit': from diffusers import HunyuanDiT2DControlNetModel as cls config = 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Canny' + elif shared.sd_model_type == 'z_image': + from diffusers import ZImageControlNetModel as cls + if '2.0' in model_id: + config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0' + elif '2.1' in model_id: + config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1' + else: + config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union' else: log.error(f'Control {what}: type={shared.sd_model_type} unsupported model') return None, None @@ -508,6 +526,16 @@ class ControlNetPipeline(): feature_extractor=None, controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list ) + elif detect.is_zimage(pipeline) and len(controlnets) > 0: + from diffusers import ZImageControlNetPipeline + self.pipeline = ZImageControlNetPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + transformer=pipeline.transformer, + scheduler=pipeline.scheduler, + controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list + ) elif len(loras) > 0: self.pipeline = pipeline for lora in loras: diff --git a/modules/control/units/detect.py b/modules/control/units/detect.py index 8d836015f..83be1093f 100644 --- a/modules/control/units/detect.py +++ b/modules/control/units/detect.py @@ -28,3 +28,6 @@ def is_qwen(model): def is_hunyuandit(model): return is_compatible(model, pattern='HunyuanDiT') + +def is_zimage(model): + return is_compatible(model, pattern='ZImage') diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 1ec5b15cb..00294d560 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -156,8 +156,8 @@ def decode(latents): image = vae.decode(tensor, return_dict=False)[0] image = (image / 2.0 + 0.5).clamp(0, 1).detach() t1 = time.time() - if (t1 - t0) > 1.0 and not first_run: - shared.log.warning(f'Decode: type="taesd" variant="{variant}" time{t1 - t0:.2f}') + if (t1 - t0) > 3.0 and not first_run: + shared.log.warning(f'Decode: type="taesd" variant="{variant}" long decode time={t1 - t0:.2f}') first_run = False return image except Exception as e: