diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index c52ba5280e..112596057d 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -31,6 +31,11 @@ jobs: runner: docker-cpu image: diffusers/diffusers-flax-cpu report: flax_cpu + - name: Fast ONNXRuntime CPU tests on Ubuntu + framework: onnxruntime + runner: docker-cpu + image: diffusers/diffusers-onnxruntime-cpu + report: onnx_cpu - name: PyTorch Example CPU tests on Ubuntu framework: pytorch_examples runner: docker-cpu diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 882b5e4d1d..2d4875b80c 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -29,6 +29,11 @@ jobs: runner: docker-tpu image: diffusers/diffusers-flax-tpu report: flax_tpu + - name: Slow ONNXRuntime CUDA tests on Ubuntu + framework: onnxruntime + runner: docker-gpu + image: diffusers/diffusers-onnxruntime-cuda + report: onnx_cuda name: ${{ matrix.config.name }} diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 3baa1a0ec9..bf830959cf 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -29,6 +29,11 @@ jobs: runner: docker-cpu image: diffusers/diffusers-flax-cpu report: flax_cpu + - name: Fast ONNXRuntime CPU tests on Ubuntu + framework: onnxruntime + runner: docker-cpu + image: diffusers/diffusers-onnxruntime-cpu + report: onnx_cpu - name: PyTorch Example CPU tests on Ubuntu framework: pytorch_examples runner: docker-cpu diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 770fcba151..6bd231853f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -63,7 +63,7 @@ if is_transformers_available(): from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME -from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME +from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME INDEX_FILE = "diffusion_pytorch_model.bin" @@ -176,7 +176,13 @@ def is_safetensors_compatible(filenames, variant=None) -> bool: def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]: filenames = set(sibling.rfilename for sibling in info.siblings) - weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME] + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] if is_transformers_available(): weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] @@ -604,7 +610,7 @@ class DiffusionPipeline(ConfigMixin): ] if from_flax: - ignore_patterns = ["*.bin", "*.safetensors", ".onnx"] + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant): ignore_patterns = ["*.bin", "*.msgpack"]