1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Tests] parallelize (#3078)

* [Tests] parallelize

* finish folder structuring

* Parallelize tests more

* Correct saving of pipelines

* make sure logging level is correct

* try again

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen
2023-04-13 14:32:57 +02:00
committed by GitHub
parent e748b3c6e1
commit 3a9d7d9758
63 changed files with 109 additions and 84 deletions

View File

@@ -21,22 +21,27 @@ jobs:
fail-fast: false
matrix:
config:
- name: Fast PyTorch CPU tests on Ubuntu
framework: pytorch
- name: Fast PyTorch Pipeline CPU tests
framework: pytorch_pipelines
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- name: Fast Flax CPU tests on Ubuntu
report: torch_cpu_pipelines
- name: Fast PyTorch Models & Schedulers CPU tests
framework: pytorch_models
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- name: Fast Flax CPU tests
framework: flax
runner: docker-cpu
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
- name: Fast ONNXRuntime CPU tests
framework: onnxruntime
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
- name: PyTorch Example CPU tests
framework: pytorch_examples
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
@@ -71,13 +76,21 @@ jobs:
run: |
python utils/print_env.py
- name: Run fast PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch' }}
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
tests/pipelines
- name: Run fast PyTorch Model Scheduler CPU tests
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
@@ -85,7 +98,7 @@ jobs:
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
tests
- name: Run fast ONNXRuntime CPU tests
if: ${{ matrix.config.framework == 'onnxruntime' }}

View File

@@ -20,7 +20,7 @@ import torch
from diffusers import UNet1DModel
from diffusers.utils import floats_tensor, slow, torch_device
from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -22,7 +22,7 @@ import torch
from diffusers import UNet2DModel
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin
logger = logging.get_logger(__name__)

View File

@@ -34,7 +34,7 @@ from diffusers.utils import (
)
from diffusers.utils.import_utils import is_xformers_available
from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin
logger = logging.get_logger(__name__)

View File

@@ -30,7 +30,7 @@ from diffusers.utils import (
)
from diffusers.utils.import_utils import is_xformers_available
from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin
logger = logging.get_logger(__name__)

View File

@@ -22,7 +22,7 @@ from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -4,7 +4,7 @@ from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
from ..test_modeling_common_flax import FlaxModelTesterMixin
from .test_modeling_common_flax import FlaxModelTesterMixin
if is_flax_available():

View File

@@ -20,7 +20,7 @@ import torch
from diffusers import VQModel
from diffusers.utils import floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -141,6 +141,8 @@ class ConfigTester(unittest.TestCase):
def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
ddim = DDIMScheduler.from_pretrained(
@@ -153,6 +155,8 @@ class ConfigTester(unittest.TestCase):
def test_load_euler_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
euler = EulerDiscreteScheduler.from_pretrained(
@@ -165,6 +169,8 @@ class ConfigTester(unittest.TestCase):
def test_load_euler_ancestral_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
euler = EulerAncestralDiscreteScheduler.from_pretrained(
@@ -177,6 +183,8 @@ class ConfigTester(unittest.TestCase):
def test_load_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pndm = PNDMScheduler.from_pretrained(
@@ -189,6 +197,8 @@ class ConfigTester(unittest.TestCase):
def test_overwrite_config_on_load(self):
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
ddpm = DDPMScheduler.from_pretrained(
@@ -212,6 +222,8 @@ class ConfigTester(unittest.TestCase):
def test_load_dpmsolver(self):
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
dpm = DPMSolverMultistepScheduler.from_pretrained(

View File

@@ -167,4 +167,4 @@ class DeprecateTester(unittest.TestCase):
with self.assertWarns(FutureWarning) as warning:
deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/test_utils.py" in warning.filename
assert "diffusers/tests/others/test_utils.py" in warning.filename

View File

@@ -28,8 +28,8 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -38,8 +38,8 @@ from diffusers import (
)
from diffusers.utils import slow, torch_device
from ...pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

View File

@@ -23,8 +23,8 @@ from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -21,8 +21,8 @@ import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
from ...pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -23,11 +23,11 @@ from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultis
from diffusers.utils import is_xformers_available, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import (
from ..pipeline_params import (
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS,
)
from ...test_pipelines_common import PipelineTesterMixin
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -23,8 +23,8 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -27,8 +27,8 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -22,8 +22,8 @@ import torch
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, load_numpy, nightly, require_torch_gpu, skip_mps, torch_device
from ...pipeline_params import IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -24,8 +24,8 @@ from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, Sp
from diffusers.utils import require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import require_note_seq, require_onnxruntime
from ...pipeline_params import TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS, TOKENS_TO_AUDIO_GENERATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS, TOKENS_TO_AUDIO_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -25,8 +25,8 @@ from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -29,7 +29,7 @@ from diffusers import (
)
from diffusers.utils.testing_utils import is_onnx_available, nightly, require_onnxruntime, require_torch_gpu
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
if is_onnx_available():

View File

@@ -35,7 +35,7 @@ from diffusers.utils.testing_utils import (
require_torch_gpu,
)
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
if is_onnx_available():

View File

@@ -26,7 +26,7 @@ from diffusers.utils.testing_utils import (
require_torch_gpu,
)
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
if is_onnx_available():

View File

@@ -36,7 +36,7 @@ from diffusers.utils.testing_utils import (
require_torch_gpu,
)
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
if is_onnx_available():

View File

@@ -40,8 +40,8 @@ from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from ...models.test_models_unet_2d_condition import create_lora_layers
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -33,8 +33,8 @@ from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_de
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
class StableDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

View File

@@ -32,8 +32,8 @@ from diffusers import (
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -34,8 +34,8 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -34,8 +34,8 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint impo
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -34,8 +34,8 @@ from diffusers import (
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -31,8 +31,8 @@ from diffusers import (
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -32,8 +32,8 @@ from diffusers import (
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -33,8 +33,8 @@ from diffusers import (
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -29,8 +29,8 @@ from diffusers import (
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -35,8 +35,8 @@ from diffusers import (
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -29,8 +29,8 @@ from diffusers import (
from diffusers.utils import load_numpy, skip_mps, slow
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@skip_mps

View File

@@ -51,8 +51,8 @@ from diffusers.utils import (
)
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -26,8 +26,8 @@ from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeli
from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, slow
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -31,8 +31,8 @@ from diffusers import (
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -15,8 +15,8 @@ from diffusers import (
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, slow, torch_device
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
class StableUnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

View File

@@ -27,8 +27,8 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import (
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
assert_mean_pixel_difference,
)

View File

@@ -85,7 +85,7 @@ class PipelineTesterMixin:
raise NotImplementedError(
"You need to set the attribute `params` in the child test class. "
"`params` are checked for if all values are present in `__call__`'s signature."
" You can set `params` using one of the common set of parameters defined in`pipeline_params.py`"
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
"image pipelines, including prompts and prompt embedding overrides."
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "

View File

@@ -28,8 +28,8 @@ from diffusers import (
)
from diffusers.utils import load_numpy, skip_mps, slow
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -20,7 +20,7 @@ import torch
from diffusers import DDIMScheduler, TextToVideoZeroPipeline
from diffusers.utils import load_pt, require_torch_gpu, slow
from ...test_pipelines_common import assert_mean_pixel_difference
from ..test_pipelines_common import assert_mean_pixel_difference
@slow

View File

@@ -25,8 +25,8 @@ from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

View File

@@ -39,8 +39,8 @@ from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import load_image, require_torch_gpu, skip_mps
from ...pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):