mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[ONNX] Don't download ONNX model by default (#4338)
* [Download] Don't download ONNX weights by default * [Download] Don't download ONNX weights by default * [Download] Don't download ONNX weights by default * fix more * finish * finish * finish
This commit is contained in:
committed by
GitHub
parent
c7250f2b8a
commit
306a7bd047
@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
_optional_components = []
|
||||
_exclude_from_cpu_offload = []
|
||||
_load_connected_pipes = False
|
||||
_is_onnx = False
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
use_onnx (`bool`, *optional*, defaults to `None`):
|
||||
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
variant (`str`, *optional*):
|
||||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
|
||||
loading `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
use_onnx (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
custom_revision = kwargs.pop("custom_revision", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
pretrained_model_name, use_auth_token, variant, revision, model_filenames
|
||||
)
|
||||
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames}
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if (
|
||||
@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
|
||||
@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
vae_encoder: OnnxRuntimeModel
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
|
||||
@@ -46,6 +46,8 @@ def preprocess(image):
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: OnnxRuntimeModel,
|
||||
|
||||
@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase):
|
||||
assert len([f for f in files if ".bin" in f]) == 8
|
||||
assert not any(".safetensors" in f for f in files)
|
||||
|
||||
def test_download_no_openvino_by_default(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-open-vino",
|
||||
cache_dir=tmpdirname,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# make sure that by default no openvino weights are downloaded
|
||||
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
|
||||
assert not any("openvino_" in f for f in files)
|
||||
|
||||
def test_download_no_onnx_by_default(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
|
||||
cache_dir=tmpdirname,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# make sure that by default no onnx weights are downloaded
|
||||
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
|
||||
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
|
||||
cache_dir=tmpdirname,
|
||||
use_onnx=True,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# if `use_onnx` is specified make sure weights are downloaded
|
||||
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
|
||||
assert any((f.endswith(".onnx")) for f in files)
|
||||
assert any((f.endswith(".pb")) for f in files)
|
||||
|
||||
def test_download_no_safety_checker(self):
|
||||
prompt = "hello"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user