mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add VAE Decode endpoint slow test (#10946)
This commit is contained in:
@@ -24,6 +24,7 @@ from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
@@ -32,6 +33,11 @@ from diffusers.video_processor import VideoProcessor
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
|
||||
|
||||
class RemoteAutoencoderKLMixin:
|
||||
shape: Tuple[int, ...] = None
|
||||
@@ -344,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests(
|
||||
512,
|
||||
512,
|
||||
)
|
||||
endpoint = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
endpoint = ENDPOINT_SD_V1
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.18215
|
||||
shift_factor = None
|
||||
@@ -368,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests(
|
||||
1024,
|
||||
1024,
|
||||
)
|
||||
endpoint = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
endpoint = ENDPOINT_SD_XL
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.13025
|
||||
shift_factor = None
|
||||
@@ -392,7 +398,7 @@ class RemoteAutoencoderKLFluxTests(
|
||||
1024,
|
||||
1024,
|
||||
)
|
||||
endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
endpoint = ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
@@ -419,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests(
|
||||
)
|
||||
height = 1024
|
||||
width = 1024
|
||||
endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
endpoint = ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
@@ -447,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
|
||||
320,
|
||||
512,
|
||||
)
|
||||
endpoint = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
endpoint = ENDPOINT_HUNYUAN_VIDEO
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.476986
|
||||
processor_cls = VideoProcessor
|
||||
@@ -456,3 +462,72 @@ class RemoteAutoencoderKLHunyuanVideoTests(
|
||||
[149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
|
||||
)
|
||||
return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
|
||||
|
||||
|
||||
class RemoteAutoencoderKLSlowTestMixin:
|
||||
channels: int = 4
|
||||
endpoint: str = None
|
||||
dtype: torch.dtype = None
|
||||
scaling_factor: float = None
|
||||
shift_factor: float = None
|
||||
width: int = None
|
||||
height: int = None
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
inputs = {
|
||||
"endpoint": self.endpoint,
|
||||
"scaling_factor": self.scaling_factor,
|
||||
"shift_factor": self.shift_factor,
|
||||
"height": self.height,
|
||||
"width": self.width,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_multi_res(self):
|
||||
inputs = self.get_dummy_inputs()
|
||||
for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
|
||||
for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
|
||||
inputs["tensor"] = torch.randn(
|
||||
(1, self.channels, height // 8, width // 8),
|
||||
device=torch_device,
|
||||
dtype=self.dtype,
|
||||
generator=torch.Generator(torch_device).manual_seed(13),
|
||||
)
|
||||
inputs["height"] = height
|
||||
inputs["width"] = width
|
||||
output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
|
||||
output.save(f"test_multi_res_{height}_{width}.png")
|
||||
|
||||
|
||||
@slow
|
||||
class RemoteAutoencoderKLSDv1SlowTests(
|
||||
RemoteAutoencoderKLSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
endpoint = ENDPOINT_SD_V1
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.18215
|
||||
shift_factor = None
|
||||
|
||||
|
||||
@slow
|
||||
class RemoteAutoencoderKLSDXLSlowTests(
|
||||
RemoteAutoencoderKLSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
endpoint = ENDPOINT_SD_XL
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.13025
|
||||
shift_factor = None
|
||||
|
||||
|
||||
@slow
|
||||
class RemoteAutoencoderKLFluxSlowTests(
|
||||
RemoteAutoencoderKLSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
channels = 16
|
||||
endpoint = ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
|
||||
Reference in New Issue
Block a user