1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[ONNX] Don't download ONNX model by default (#4338)

* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* fix more

* finish

* finish

* finish
This commit is contained in:
Patrick von Platen
2023-07-28 14:02:48 +02:00
committed by GitHub
parent c7250f2b8a
commit 306a7bd047
7 changed files with 74 additions and 1 deletions

View File

@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
_optional_components = []
_exclude_from_cpu_offload = []
_load_connected_pipes = False
_is_onnx = False
def register_modules(self, **kwargs):
# import it here to avoid circular import
@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin):
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `False`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
Returns:
`os.PathLike`:
@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin):
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
if use_safetensors and not is_safetensors_available():
@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames}
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin):
):
ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin):
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:

View File

@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
self,

View File

@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
self,

View File

@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
self,

View File

@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel

View File

@@ -46,6 +46,8 @@ def preprocess(image):
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
_is_onnx = True
def __init__(
self,
vae: OnnxRuntimeModel,

View File

@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase):
assert len([f for f in files if ".bin" in f]) == 8
assert not any(".safetensors" in f for f in files)
def test_download_no_openvino_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-open-vino",
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no openvino weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any("openvino_" in f for f in files)
def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no onnx weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
use_onnx=True,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# if `use_onnx` is specified make sure weights are downloaded
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files)
def test_download_no_safety_checker(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(