mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[hybrid inference 🍯🐝] Add VAE encode (#11017)
* [hybrid inference 🍯🐝] Add VAE encode
* _toctree: add vae encode
* Add endpoints, tests
* vae_encode docs
* vae encode benchmarks
* api reference
* changelog
* Update docs/source/en/hybrid_inference/overview.md
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
docs/source/en/_toctree.yml
@@ -81,6 +81,8 @@
    title: Overview
  - local: hybrid_inference/vae_decode
    title: VAE Decode
  - local: hybrid_inference/vae_encode
    title: VAE Encode
  - local: hybrid_inference/api_reference
    title: API Reference
  title: Hybrid Inference
docs/source/en/hybrid_inference/api_reference.md
@@ -3,3 +3,7 @@
## Remote Decode

[[autodoc]] utils.remote_utils.remote_decode

## Remote Encode

[[autodoc]] utils.remote_utils.remote_encode
docs/source/en/hybrid_inference/overview.md
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
## Available Models

* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.

---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

## Changelog

- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding

## Contents

The documentation is organized into two sections:
The documentation is organized into three sections:

* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.
docs/source/en/hybrid_inference/vae_encode.md (new file, 183 lines)
@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference

VAE encode is used for training, image-to-image, and image-to-video, turning images or videos into latent representations.

## Memory

These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.

For the majority of these GPUs, the memory usage % dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding has to be used, which increases the time taken and impacts quality.
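For reference, the tiled fallback mentioned above looks roughly like this when run locally. The snippet below is an illustrative sketch (not taken from the benchmark scripts) and assumes the SD v1 VAE [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) fits on the GPU:

```python
import torch

from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

# Load the SD v1 VAE and switch on tiled encoding to bound peak VRAM.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
vae.enable_tiling()  # trades extra time (and some quality) for a much smaller memory peak

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
pixels = VaeImageProcessor().preprocess(image).to("cuda", torch.float16)
with torch.no_grad():
    latent = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
```
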
<details><summary>SD v1.5</summary>

| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |

</details>
<details><summary>SDXL</summary>

| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |

</details>
## Available VAEs

|   | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |

> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).

## Code

> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`

A helper method simplifies interacting with Hybrid Inference.

```python
from diffusers.utils.remote_utils import remote_encode
```

### Basic example

Let's encode an image, then decode it to demonstrate.

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>

<details><summary>Code</summary>

```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

decoded = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```

</details>

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>

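The same pattern works with the other endpoints in the table above. As an illustrative sketch (not part of the original example), here is the equivalent round trip with the Stable Diffusion XL endpoints and `scaling_factor=0.13025`:

```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

# Encode with the SD XL endpoint; the latent comes back as [1, 4, H // 8, W // 8].
latent = remote_encode(
    endpoint="https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.13025,
)

# Decode with the SD XL decode endpoint to get a PIL image back.
decoded = remote_decode(
    endpoint="https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.13025,
)
```
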
### Generation

Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!

<details><summary>Code</summary>

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    vae=None,
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))

init_latent = remote_encode(
    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
    image=init_image,
    scaling_factor=0.18215,
)

prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
    prompt=prompt,
    image=init_latent,
    strength=0.75,
    output_type="latent",
).images

image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```

</details>

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>

## Integrations

* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
src/diffusers/utils/constants.py
@@ -56,3 +56,14 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version

if USE_PEFT_BACKEND and _CHECK_PEFT:
    dep_version_check("peft")


DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"


ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
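Not part of the diff, but as a quick sketch of how these new constants pair with the remote helpers (the scaling factor follows the `remote_encode` docstring below):

```python
from diffusers.utils import load_image
from diffusers.utils.constants import DECODE_ENDPOINT_SD_V1, ENCODE_ENDPOINT_SD_V1
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")

# SD v1 uses scaling_factor=0.18215 and no shift_factor.
latent = remote_encode(endpoint=ENCODE_ENDPOINT_SD_V1, image=image, scaling_factor=0.18215)
roundtrip = remote_decode(endpoint=DECODE_ENDPOINT_SD_V1, tensor=latent, scaling_factor=0.18215)
```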
src/diffusers/utils/remote_utils.py
@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
    return "unknown"


def check_inputs(
def check_inputs_decode(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +89,7 @@ def check_inputs(
        )


def postprocess(
def postprocess_decode(
    response: requests.Response,
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +142,7 @@ def postprocess(
    return output


def prepare(
def prepare_decode(
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
@@ -293,7 +293,7 @@ def remote_decode(
            standard_warn=False,
        )
        output_tensor_type = "binary"
    check_inputs(
    check_inputs_decode(
        endpoint,
        tensor,
        processor,
@@ -309,7 +309,7 @@ def remote_decode(
        height,
        width,
    )
    kwargs = prepare(
    kwargs = prepare_decode(
        tensor=tensor,
        processor=processor,
        do_scaling=do_scaling,
@@ -324,7 +324,7 @@ def remote_decode(
    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        raise RuntimeError(response.json())
    output = postprocess(
    output = postprocess_decode(
        response=response,
        processor=processor,
        output_type=output_type,
@@ -332,3 +332,94 @@ def remote_decode(
        partial_postprocess=partial_postprocess,
    )
    return output


def check_inputs_encode(
    endpoint: str,
    image: Union["torch.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
):
    pass


def postprocess_encode(
    response: requests.Response,
):
    output_tensor = response.content
    parameters = response.headers
    shape = json.loads(parameters["shape"])
    dtype = parameters["dtype"]
    torch_dtype = DTYPE_MAP[dtype]
    output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
    return output_tensor


def prepare_encode(
    image: Union["torch.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
):
    headers = {}
    parameters = {}
    if scaling_factor is not None:
        parameters["scaling_factor"] = scaling_factor
    if shift_factor is not None:
        parameters["shift_factor"] = shift_factor
    if isinstance(image, torch.Tensor):
        data = safetensors.torch._tobytes(image, "tensor")
        parameters["shape"] = list(image.shape)
        parameters["dtype"] = str(image.dtype).split(".")[-1]
    else:
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        data = buffer.getvalue()
    return {"data": data, "params": parameters, "headers": headers}


def remote_encode(
    endpoint: str,
    image: Union["torch.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
) -> "torch.Tensor":
    """
    Hugging Face Hybrid Inference that allows running VAE encode remotely.

    Args:
        endpoint (`str`):
            Endpoint for Remote Encode.
        image (`torch.Tensor` or `PIL.Image.Image`):
            Image to be encoded.
        scaling_factor (`float`, *optional*):
            Scaling is applied when passed, e.g. `latents * self.vae.config.scaling_factor`.
                - SD v1: 0.18215
                - SD XL: 0.13025
                - Flux: 0.3611
            If `None`, input must be passed with scaling applied.
        shift_factor (`float`, *optional*):
            Shift is applied when passed, e.g. `latents - self.vae.config.shift_factor`.
                - Flux: 0.1159
            If `None`, input must be passed with shift applied.

    Returns:
        output (`torch.Tensor`).
    """
    check_inputs_encode(
        endpoint,
        image,
        scaling_factor,
        shift_factor,
    )
    kwargs = prepare_encode(
        image=image,
        scaling_factor=scaling_factor,
        shift_factor=shift_factor,
    )
    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        raise RuntimeError(response.json())
    output = postprocess_encode(
        response=response,
    )
    return output
tests/remote/test_remote_decode.py
@@ -21,7 +21,15 @@ import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.remote_utils import remote_decode
from diffusers.utils.constants import (
    DECODE_ENDPOINT_FLUX,
    DECODE_ENDPOINT_HUNYUAN_VIDEO,
    DECODE_ENDPOINT_SD_V1,
    DECODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
    remote_decode,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    slow,
@@ -33,11 +41,6 @@ from diffusers.video_processor import VideoProcessor

enable_full_determinism()

ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"


class RemoteAutoencoderKLMixin:
    shape: Tuple[int, ...] = None
@@ -350,7 +353,7 @@ class RemoteAutoencoderKLSDv1Tests(
        512,
        512,
    )
    endpoint = ENDPOINT_SD_V1
    endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
@@ -374,7 +377,7 @@ class RemoteAutoencoderKLSDXLTests(
        1024,
        1024,
    )
    endpoint = ENDPOINT_SD_XL
    endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
@@ -398,7 +401,7 @@ class RemoteAutoencoderKLFluxTests(
        1024,
        1024,
    )
    endpoint = ENDPOINT_FLUX
    endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
@@ -425,7 +428,7 @@ class RemoteAutoencoderKLFluxPackedTests(
    )
    height = 1024
    width = 1024
    endpoint = ENDPOINT_FLUX
    endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
@@ -453,7 +456,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
        320,
        512,
    )
    endpoint = ENDPOINT_HUNYUAN_VIDEO
    endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
    dtype = torch.float16
    scaling_factor = 0.476986
    processor_cls = VideoProcessor
@@ -504,7 +507,7 @@ class RemoteAutoencoderKLSDv1SlowTests(
    RemoteAutoencoderKLSlowTestMixin,
    unittest.TestCase,
):
    endpoint = ENDPOINT_SD_V1
    endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
@@ -515,7 +518,7 @@ class RemoteAutoencoderKLSDXLSlowTests(
    RemoteAutoencoderKLSlowTestMixin,
    unittest.TestCase,
):
    endpoint = ENDPOINT_SD_XL
    endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
@@ -527,7 +530,7 @@ class RemoteAutoencoderKLFluxSlowTests(
    unittest.TestCase,
):
    channels = 16
    endpoint = ENDPOINT_FLUX
    endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
tests/remote/test_remote_encode.py (new file, 224 lines)
@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import PIL.Image
import torch

from diffusers.utils import load_image
from diffusers.utils.constants import (
    DECODE_ENDPOINT_FLUX,
    DECODE_ENDPOINT_SD_V1,
    DECODE_ENDPOINT_SD_XL,
    ENCODE_ENDPOINT_FLUX,
    ENCODE_ENDPOINT_SD_V1,
    ENCODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
    remote_decode,
    remote_encode,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    slow,
)


enable_full_determinism()

IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"


class RemoteAutoencoderKLEncodeMixin:
    channels: int = None
    endpoint: str = None
    decode_endpoint: str = None
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_image_input(self):
        inputs = self.get_dummy_inputs()
        height, width = inputs["image"].height, inputs["image"].width
        output = remote_encode(**inputs)
        self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
        decoded = remote_decode(
            tensor=output,
            endpoint=self.decode_endpoint,
            scaling_factor=self.scaling_factor,
            shift_factor=self.shift_factor,
            image_format="png",
        )
        self.assertEqual(decoded.height, height)
        self.assertEqual(decoded.width, width)
        # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
        # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
        # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?


class RemoteAutoencoderKLSDv1Tests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None


class RemoteAutoencoderKLSDXLTests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None


class RemoteAutoencoderKLFluxTests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    channels = 16
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159


class RemoteAutoencoderKLEncodeSlowTestMixin:
    channels: int = 4
    endpoint: str = None
    decode_endpoint: str = None
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_multi_res(self):
        inputs = self.get_dummy_inputs()
        for height in {
            320,
            512,
            640,
            704,
            896,
            1024,
            1208,
            1384,
            1536,
            1608,
            1864,
            2048,
        }:
            for width in {
                320,
                512,
                640,
                704,
                896,
                1024,
                1208,
                1384,
                1536,
                1608,
                1864,
                2048,
            }:
                inputs["image"] = inputs["image"].resize(
                    (
                        width,
                        height,
                    )
                )
                output = remote_encode(**inputs)
                self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
                decoded = remote_decode(
                    tensor=output,
                    endpoint=self.decode_endpoint,
                    scaling_factor=self.scaling_factor,
                    shift_factor=self.shift_factor,
                    image_format="png",
                )
                self.assertEqual(decoded.height, height)
                self.assertEqual(decoded.width, width)
                decoded.save(f"test_multi_res_{height}_{width}.png")


@slow
class RemoteAutoencoderKLSDv1SlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None


@slow
class RemoteAutoencoderKLSDXLSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None


@slow
class RemoteAutoencoderKLFluxSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    channels = 16
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159