# sdnext/modules/sd_vae_remote.py

import io
import time
import json
import torch
import requests
from PIL import Image
from safetensors.torch import _tobytes

hf_decode_endpoints = {
    'sd': 'https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud',
    'sdxl': 'https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud',
    'f1': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
    'hunyuanvideo': 'https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud',
}
hf_decode_endpoints['pixartalpha'] = hf_decode_endpoints['sd']
hf_decode_endpoints['pixartsigma'] = hf_decode_endpoints['sdxl']
hf_decode_endpoints['hunyuandit'] = hf_decode_endpoints['sdxl']
hf_decode_endpoints['auraflow'] = hf_decode_endpoints['sdxl']
hf_decode_endpoints['omnigen'] = hf_decode_endpoints['sdxl']
hf_decode_endpoints['h1'] = hf_decode_endpoints['f1']
hf_decode_endpoints['chroma'] = hf_decode_endpoints['f1']
hf_decode_endpoints['zimage'] = hf_decode_endpoints['f1']
hf_decode_endpoints['lumina2'] = hf_decode_endpoints['f1']

hf_encode_endpoints = {
    'sd': 'https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud',
    'sdxl': 'https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud',
    'f1': 'https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud',
    'chroma': 'https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud',
}
hf_encode_endpoints['pixartalpha'] = hf_encode_endpoints['sd']
hf_encode_endpoints['pixartsigma'] = hf_encode_endpoints['sdxl']
hf_encode_endpoints['hunyuandit'] = hf_encode_endpoints['sdxl']
hf_encode_endpoints['auraflow'] = hf_encode_endpoints['sdxl']
hf_encode_endpoints['omnigen'] = hf_encode_endpoints['sdxl']
hf_encode_endpoints['h1'] = hf_encode_endpoints['f1']
hf_encode_endpoints['zimage'] = hf_encode_endpoints['f1']
hf_encode_endpoints['lumina2'] = hf_encode_endpoints['f1']

dtypes = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
}
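

# Illustrative sketch (not part of the original module): the decode endpoints exchange raw tensor
# bytes, so a round-trip through _tobytes plus the "shape"/"dtype" metadata that travels in the
# request params and response headers is enough to rebuild a tensor on either side. The helper
# name below is hypothetical and assumes a contiguous cpu tensor.
def _roundtrip_tensor_example(tensor: torch.Tensor) -> torch.Tensor:
    payload = _tobytes(tensor, "tensor")  # serialize, as done for the request body in remote_decode
    shape = list(tensor.shape)  # metadata sent alongside the binary payload
    dtype = str(tensor.dtype).split(".", maxsplit=1)[-1]
    return torch.frombuffer(bytearray(payload), dtype=dtypes[dtype]).reshape(shape)  # rebuild as in the raw decode path
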

def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_type: str | None = None):
    from modules import devices, shared, errors, modelloader
    tensors = []
    content = 0
    # resolve the remote decode endpoint for the loaded model type
    model_type = model_type or shared.sd_model_type
    url = hf_decode_endpoints.get(model_type, None)
    if url is None:
        shared.log.error(f'Decode: type="remote" type={model_type} unsupported')
        return tensors
    t0 = time.time()
    modelloader.hf_login()
    # copy latents to cpu in the working dtype and make sure they carry a batch dimension
    latent_copy = latents.detach().clone().to(device=devices.cpu, dtype=devices.dtype)
    latent_copy = latent_copy.unsqueeze(0) if len(latent_copy.shape) == 3 else latent_copy
    if model_type == 'hunyuanvideo':
        latent_copy = latent_copy.unsqueeze(0) if len(latent_copy.shape) == 4 else latent_copy
    for i in range(latent_copy.shape[0]):
        params = {}
        try:
            latent = latent_copy[i]
            if model_type not in ['f1', 'chroma']:
                latent = latent.unsqueeze(0)
            # shape/dtype describe the binary tensor payload sent in the request body
            params = {
                "input_tensor_type": "binary",
                "shape": list(latent.shape),
                "dtype": str(latent.dtype).split(".", maxsplit=1)[-1],
            }
            headers = { "Content-Type": "tensor/binary" }
            if 'video' in model_type:
                params["partial_postprocess"] = False
                params["output_type"] = "pt"
                params["output_tensor_type"] = "binary"
                headers["Accept"] = "tensor/binary"
            elif shared.opts.remote_vae_type == 'png':
                params["image_format"] = "png"
                params["output_type"] = "pil"
                headers["Accept"] = "image/png"
            elif shared.opts.remote_vae_type == 'jpg':
                params["image_format"] = "jpg"
                params["output_type"] = "pil"
                headers["Accept"] = "image/jpeg"
            elif shared.opts.remote_vae_type == 'raw':
                params["partial_postprocess"] = False
                params["output_type"] = "pt"
                params["output_tensor_type"] = "binary"
                headers["Accept"] = "tensor/binary"
            if model_type in {'f1', 'h1', 'zimage', 'lumina2', 'chroma'} and (width > 0) and (height > 0):
                params['width'] = width
                params['height'] = height
            if shared.sd_model.vae is not None and shared.sd_model.vae.config is not None:
                params['scaling_factor'] = shared.sd_model.vae.config.get("scaling_factor", None)
                params['shift_factor'] = shared.sd_model.vae.config.get("shift_factor", None)
            response = requests.post(
                url=url,
                headers=headers,
                params=params,
                data=_tobytes(latent, "tensor"),
                timeout=300,
            )
            if not response.ok:
                shared.log.error(f'Decode: type="remote" model={model_type} code={response.status_code} shape={latent.shape} url="{url}" args={params} headers={response.headers} response={response.json()}')
            else:
                content += len(response.content)
                # raw/video responses come back as binary tensors, png/jpg as encoded images
                if shared.opts.remote_vae_type == 'raw' or 'video' in model_type:
                    shape = json.loads(response.headers["shape"])
                    dtype = response.headers["dtype"]
                    tensor = torch.frombuffer(bytearray(response.content), dtype=dtypes[dtype]).reshape(shape)
                    tensors.append(tensor)
                elif shared.opts.remote_vae_type == 'jpg' or shared.opts.remote_vae_type == 'png':
                    image = Image.open(io.BytesIO(response.content)).convert("RGB")
                    tensors.append(image)
        except Exception as e:
            shared.log.error(f'Decode: type="remote" model={model_type} {e}')
            errors.display(e, 'VAE')
    if len(tensors) > 0 and shared.opts.remote_vae_type == 'raw':
        tensors = torch.cat(tensors, dim=0)
    t1 = time.time()
    shared.log.debug(f'Decode: type="remote" model={model_type} mode={shared.opts.remote_vae_type} args={params} bytes={content} time={t1-t0:.3f}s')
    return tensors
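

# Hedged usage sketch (not from the original file): decoding a single SDXL latent into a PIL image.
# Assumes a model is loaded, shared.opts.remote_vae_type is 'png' or 'jpg', and HF login succeeds;
# the latent shape below is a hypothetical 1024x1024 SDXL latent in (C, H, W) layout.
def _example_remote_decode():
    latent = torch.randn(4, 128, 128, dtype=torch.float16)
    images = remote_decode(latent, model_type='sdxl')  # returns a list of PIL images in png/jpg mode
    return images
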

def remote_encode(images: list[Image.Image], model_type: str | None = None):
    from diffusers.utils import remote_utils
    from modules import devices, shared, errors, modelloader
    if not shared.opts.remote_vae_encode:
        return images
    tensors = []
    # resolve the remote encode endpoint for the loaded model type
    model_type = model_type or shared.sd_model_type
    url = hf_encode_endpoints.get(model_type, None)
    if url is None:
        shared.log.error(f'Encode: type="remote" type={model_type} unsupported')
        return images
    t0 = time.time()
    modelloader.hf_login()
    if isinstance(images, Image.Image):
        images = [images]
    for init_image in images:
        try:
            # encode each image via the diffusers remote vae client helper
            init_latent = remote_utils.remote_encode(
                endpoint=url,
                image=init_image,
                scaling_factor=shared.sd_model.vae.config.get("scaling_factor", None),
                shift_factor=shared.sd_model.vae.config.get("shift_factor", None),
            )
            tensors.append(init_latent)
        except Exception as e:
            shared.log.error(f'Encode: type="remote" model={model_type} {e}')
            errors.display(e, 'VAE')
    if len(tensors) > 0 and torch.is_tensor(tensors[0]):
        tensors = torch.cat(tensors, dim=0)
        tensors = tensors.to(dtype=devices.dtype)
    else:
        return images
    t1 = time.time()
    shared.log.debug(f'Encode: type="remote" model={model_type} mode={shared.opts.remote_vae_type} image={images} latent={tensors.shape} time={t1-t0:.3f}s')
    return tensors
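

# Hedged usage sketch (not part of the original module): encoding an init image back into a latent.
# Assumes shared.opts.remote_vae_encode is enabled and the loaded model exposes vae.config; the
# image size is an arbitrary example.
def _example_remote_encode():
    init = Image.new('RGB', (1024, 1024))
    latents = remote_encode([init], model_type='sdxl')  # returns a batched latent tensor on success
    return latents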