
[hybrid inference 🍯🐝] Add VAE encode (#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: hlky
Date: 2025-03-12 11:23:41 +00:00
Committed by: GitHub
Parent: 8b4f8ba764
Commit: 733b44ac82
8 changed files with 546 additions and 22 deletions


@@ -81,6 +81,8 @@
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
- local: hybrid_inference/vae_encode
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference


@@ -3,3 +3,7 @@
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
## Remote Encode
[[autodoc]] utils.remote_utils.remote_encode


@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
## Available Models
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
- * **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
+ * **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
## Changelog
- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding
## Contents
- The documentation is organized into two sections:
+ The documentation is organized into three sections:
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.


@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference
VAE encode is used for training, image-to-image, and image-to-video - turning images or videos into latent representations.
## Memory
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs, the memory usage % means that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding has to be used, which increases the time taken and can impact quality. A sketch of how a single local measurement can be reproduced follows the tables.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
</details>
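The numbers above come from running the VAE encoder locally, with and without tiling. Below is a minimal sketch of how a single measurement could be reproduced; the resolution, the use of peak allocated memory as the metric, and other details are assumptions rather than the exact benchmark setup.

```python
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

device = "cuda"
# SD v1 VAE; swap in `madebyollin/sdxl-vae-fp16-fix` for the SDXL rows.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
# vae.enable_tiling()  # uncomment to measure the "Tiled" columns

processor = VaeImageProcessor()
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
).resize((1024, 1024))
pixels = processor.preprocess(image).to(device=device, dtype=torch.float16)

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    latent = vae.encode(pixels).latent_dist.sample()
peak = torch.cuda.max_memory_allocated()
total = torch.cuda.get_device_properties(0).total_memory
print(f"latent shape: {tuple(latent.shape)}, peak VRAM: {100 * peak / total:.2f}%")
```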
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_encode
```
### Basic example
Let's encode an image, then decode it to demonstrate.
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>
<details><summary>Code</summary>
```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

decoded = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>
### Generation
Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!
<details><summary>Code</summary>
```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

# Load the pipeline without a VAE; encoding and decoding are done remotely.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    vae=None,
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))

# Encode the init image remotely with the SD v1 encode endpoint.
init_latent = remote_encode(
    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
    image=init_image,
    scaling_factor=0.18215,
)

prompt = "A fantasy landscape, trending on artstation"

# Run the denoising loop locally and keep the output as latents.
latent = pipe(
    prompt=prompt,
    image=init_latent,
    strength=0.75,
    output_type="latent",
).images

# Decode the latents remotely with the SD v1 decode endpoint.
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.


@@ -56,3 +56,14 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
if USE_PEFT_BACKEND and _CHECK_PEFT:
dep_version_check("peft")
DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
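As a quick usage sketch, the new constants can be imported and passed straight to the remote helpers instead of hard-coding endpoint URLs (the SD v1 scaling factor of 0.18215 is taken from the documentation above):

```python
from diffusers.utils import load_image
from diffusers.utils.constants import DECODE_ENDPOINT_SD_V1, ENCODE_ENDPOINT_SD_V1
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
# Round-trip through the remote SD v1 VAE.
latent = remote_encode(endpoint=ENCODE_ENDPOINT_SD_V1, image=image, scaling_factor=0.18215)
decoded = remote_decode(endpoint=DECODE_ENDPOINT_SD_V1, tensor=latent, scaling_factor=0.18215)
```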


@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
return "unknown"
- def check_inputs(
+ def check_inputs_decode(
endpoint: str,
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +89,7 @@ def check_inputs(
)
- def postprocess(
+ def postprocess_decode(
response: requests.Response,
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +142,7 @@ def postprocess(
return output
- def prepare(
+ def prepare_decode(
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
do_scaling: bool = True,
@@ -293,7 +293,7 @@ def remote_decode(
standard_warn=False,
)
output_tensor_type = "binary"
- check_inputs(
+ check_inputs_decode(
endpoint,
tensor,
processor,
@@ -309,7 +309,7 @@ def remote_decode(
height,
width,
)
- kwargs = prepare(
+ kwargs = prepare_decode(
tensor=tensor,
processor=processor,
do_scaling=do_scaling,
@@ -324,7 +324,7 @@ def remote_decode(
response = requests.post(endpoint, **kwargs)
if not response.ok:
raise RuntimeError(response.json())
- output = postprocess(
+ output = postprocess_decode(
response=response,
processor=processor,
output_type=output_type,
@@ -332,3 +332,94 @@ def remote_decode(
partial_postprocess=partial_postprocess,
)
return output
def check_inputs_encode(
endpoint: str,
image: Union["torch.Tensor", Image.Image],
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
):
pass
def postprocess_encode(
response: requests.Response,
):
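# The endpoint returns the latent as raw bytes; its shape and dtype are read back from the response headers.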
output_tensor = response.content
parameters = response.headers
shape = json.loads(parameters["shape"])
dtype = parameters["dtype"]
torch_dtype = DTYPE_MAP[dtype]
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
return output_tensor
def prepare_encode(
image: Union["torch.Tensor", Image.Image],
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
):
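# Tensors are serialized with their shape and dtype sent as request parameters; PIL images are sent as PNG bytes.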
headers = {}
parameters = {}
if scaling_factor is not None:
parameters["scaling_factor"] = scaling_factor
if shift_factor is not None:
parameters["shift_factor"] = shift_factor
if isinstance(image, torch.Tensor):
data = safetensors.torch._tobytes(image, "tensor")
parameters["shape"] = list(image.shape)
parameters["dtype"] = str(image.dtype).split(".")[-1]
else:
buffer = io.BytesIO()
image.save(buffer, format="PNG")
data = buffer.getvalue()
return {"data": data, "params": parameters, "headers": headers}
def remote_encode(
endpoint: str,
image: Union["torch.Tensor", Image.Image],
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
) -> "torch.Tensor":
"""
Hugging Face Hybrid Inference that allows running VAE encode remotely.
Args:
endpoint (`str`):
Endpoint for Remote Encode.
image (`torch.Tensor` or `PIL.Image.Image`):
Image to be encoded.
scaling_factor (`float`, *optional*):
Scaling is applied when passed, e.g. `latents * self.vae.config.scaling_factor`.
- SD v1: 0.18215
- SD XL: 0.13025
- Flux: 0.3611
If `None`, no scaling is applied to the returned latent.
shift_factor (`float`, *optional*):
Shift is applied when passed, e.g. `latents - self.vae.config.shift_factor`.
- Flux: 0.1159
If `None`, no shift is applied to the returned latent.
Returns:
output (`torch.Tensor`).
"""
check_inputs_encode(
endpoint,
image,
scaling_factor,
shift_factor,
)
kwargs = prepare_encode(
image=image,
scaling_factor=scaling_factor,
shift_factor=shift_factor,
)
response = requests.post(endpoint, **kwargs)
if not response.ok:
raise RuntimeError(response.json())
output = postprocess_encode(
response=response,
)
return output
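For reference, here is a short usage sketch of `remote_encode` with a tensor input instead of a PIL image. The endpoint URL is the SD v1 encode endpoint from the docs above; that the endpoint expects `VaeImageProcessor`-style preprocessed pixel values for tensor inputs is an assumption.

```python
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_encode

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
# Preprocess locally to a [-1, 1] float tensor and send the raw tensor
# (see `prepare_encode`: tensors are serialized as bytes, PIL images as PNG).
pixels = VaeImageProcessor().preprocess(image)

latent = remote_encode(
    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
    image=pixels,
    scaling_factor=0.18215,
)
print(latent.shape)
```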


@@ -21,7 +21,15 @@ import PIL.Image
import torch
from diffusers.image_processor import VaeImageProcessor
- from diffusers.utils.remote_utils import remote_decode
+ from diffusers.utils.constants import (
+     DECODE_ENDPOINT_FLUX,
+     DECODE_ENDPOINT_HUNYUAN_VIDEO,
+     DECODE_ENDPOINT_SD_V1,
+     DECODE_ENDPOINT_SD_XL,
+ )
+ from diffusers.utils.remote_utils import (
+     remote_decode,
+ )
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
@@ -33,11 +41,6 @@ from diffusers.video_processor import VideoProcessor
enable_full_determinism()
- ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
- ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
- ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
- ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
class RemoteAutoencoderKLMixin:
shape: Tuple[int, ...] = None
@@ -350,7 +353,7 @@ class RemoteAutoencoderKLSDv1Tests(
512,
512,
)
- endpoint = ENDPOINT_SD_V1
+ endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@@ -374,7 +377,7 @@ class RemoteAutoencoderKLSDXLTests(
1024,
1024,
)
- endpoint = ENDPOINT_SD_XL
+ endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@@ -398,7 +401,7 @@ class RemoteAutoencoderKLFluxTests(
1024,
1024,
)
- endpoint = ENDPOINT_FLUX
+ endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
@@ -425,7 +428,7 @@ class RemoteAutoencoderKLFluxPackedTests(
)
height = 1024
width = 1024
- endpoint = ENDPOINT_FLUX
+ endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
@@ -453,7 +456,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
320,
512,
)
- endpoint = ENDPOINT_HUNYUAN_VIDEO
+ endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
dtype = torch.float16
scaling_factor = 0.476986
processor_cls = VideoProcessor
@@ -504,7 +507,7 @@ class RemoteAutoencoderKLSDv1SlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
- endpoint = ENDPOINT_SD_V1
+ endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@@ -515,7 +518,7 @@ class RemoteAutoencoderKLSDXLSlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
- endpoint = ENDPOINT_SD_XL
+ endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@@ -527,7 +530,7 @@ class RemoteAutoencoderKLFluxSlowTests(
unittest.TestCase,
):
channels = 16
- endpoint = ENDPOINT_FLUX
+ endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159


@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import PIL.Image
import torch
from diffusers.utils import load_image
from diffusers.utils.constants import (
DECODE_ENDPOINT_FLUX,
DECODE_ENDPOINT_SD_V1,
DECODE_ENDPOINT_SD_XL,
ENCODE_ENDPOINT_FLUX,
ENCODE_ENDPOINT_SD_V1,
ENCODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
remote_decode,
remote_encode,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
)
enable_full_determinism()
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
class RemoteAutoencoderKLEncodeMixin:
channels: int = None
endpoint: str = None
decode_endpoint: str = None
dtype: torch.dtype = None
scaling_factor: float = None
shift_factor: float = None
image: PIL.Image.Image = None
def get_dummy_inputs(self):
if self.image is None:
self.image = load_image(IMAGE)
inputs = {
"endpoint": self.endpoint,
"image": self.image,
"scaling_factor": self.scaling_factor,
"shift_factor": self.shift_factor,
}
return inputs
def test_image_input(self):
inputs = self.get_dummy_inputs()
height, width = inputs["image"].height, inputs["image"].width
output = remote_encode(**inputs)
self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
decoded = remote_decode(
tensor=output,
endpoint=self.decode_endpoint,
scaling_factor=self.scaling_factor,
shift_factor=self.shift_factor,
image_format="png",
)
self.assertEqual(decoded.height, height)
self.assertEqual(decoded.width, width)
# image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
# decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
# TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
class RemoteAutoencoderKLSDv1Tests(
RemoteAutoencoderKLEncodeMixin,
unittest.TestCase,
):
channels = 4
endpoint = ENCODE_ENDPOINT_SD_V1
decode_endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
class RemoteAutoencoderKLSDXLTests(
RemoteAutoencoderKLEncodeMixin,
unittest.TestCase,
):
channels = 4
endpoint = ENCODE_ENDPOINT_SD_XL
decode_endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
class RemoteAutoencoderKLFluxTests(
RemoteAutoencoderKLEncodeMixin,
unittest.TestCase,
):
channels = 16
endpoint = ENCODE_ENDPOINT_FLUX
decode_endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
class RemoteAutoencoderKLEncodeSlowTestMixin:
channels: int = 4
endpoint: str = None
decode_endpoint: str = None
dtype: torch.dtype = None
scaling_factor: float = None
shift_factor: float = None
image: PIL.Image.Image = None
def get_dummy_inputs(self):
if self.image is None:
self.image = load_image(IMAGE)
inputs = {
"endpoint": self.endpoint,
"image": self.image,
"scaling_factor": self.scaling_factor,
"shift_factor": self.shift_factor,
}
return inputs
def test_multi_res(self):
inputs = self.get_dummy_inputs()
for height in {
320,
512,
640,
704,
896,
1024,
1208,
1384,
1536,
1608,
1864,
2048,
}:
for width in {
320,
512,
640,
704,
896,
1024,
1208,
1384,
1536,
1608,
1864,
2048,
}:
inputs["image"] = inputs["image"].resize(
(
width,
height,
)
)
output = remote_encode(**inputs)
self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
decoded = remote_decode(
tensor=output,
endpoint=self.decode_endpoint,
scaling_factor=self.scaling_factor,
shift_factor=self.shift_factor,
image_format="png",
)
self.assertEqual(decoded.height, height)
self.assertEqual(decoded.width, width)
decoded.save(f"test_multi_res_{height}_{width}.png")
@slow
class RemoteAutoencoderKLSDv1SlowTests(
RemoteAutoencoderKLEncodeSlowTestMixin,
unittest.TestCase,
):
endpoint = ENCODE_ENDPOINT_SD_V1
decode_endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@slow
class RemoteAutoencoderKLSDXLSlowTests(
RemoteAutoencoderKLEncodeSlowTestMixin,
unittest.TestCase,
):
endpoint = ENCODE_ENDPOINT_SD_XL
decode_endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@slow
class RemoteAutoencoderKLFluxSlowTests(
RemoteAutoencoderKLEncodeSlowTestMixin,
unittest.TestCase,
):
channels = 16
endpoint = ENCODE_ENDPOINT_FLUX
decode_endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159