mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix
This commit is contained in:
@@ -43,6 +43,7 @@ if is_torch_available():
|
||||
from .testing_utils import (
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
parse_flag_from_env,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
|
||||
@@ -4,11 +4,14 @@ import os
|
||||
import random
|
||||
import re
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from distutils.util import strtobool
|
||||
from io import StringIO
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import requests
|
||||
@@ -165,6 +168,19 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
return image
|
||||
|
||||
|
||||
def load_numpy(path) -> np.ndarray:
|
||||
if not path.startswith("http://") or path.startswith("https://"):
|
||||
path = os.path.join(
|
||||
"https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
|
||||
)
|
||||
|
||||
response = requests.get(path)
|
||||
response.raise_for_status()
|
||||
array = np.load(BytesIO(response.content))
|
||||
|
||||
return array
|
||||
|
||||
|
||||
# --- pytest conf functions --- #
|
||||
|
||||
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
|
||||
|
||||
@@ -21,7 +21,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
@@ -411,6 +411,9 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -418,11 +421,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
|
||||
batch_size, channels, height, width = shape
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype)
|
||||
|
||||
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
|
||||
@@ -437,9 +437,9 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
return torch.randn(shape, device=torch_device, generator=generator, dtype=dtype)
|
||||
hidden_states = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return hidden_states
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user