Mirror of https://github.com/huggingface/diffusers.git
[Pipeline download] Improve pipeline download for index and passed components (#2980)
* [Pipeline download] Improve pipeline download for index and passed components
* correct
* add more tests
* up
committed by Daniel Gu
parent fe03d5bce4
commit 0aa92a53ad
@@ -134,7 +134,7 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray
 
 
-def is_safetensors_compatible(filenames, variant=None) -> bool:
+def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
     - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
@@ -150,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
 
     sf_filenames = set()
 
+    passed_components = passed_components or []
+
     for filename in filenames:
         _, extension = os.path.splitext(filename)
 
+        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+            continue
+
         if extension == ".bin":
             pt_filenames.append(filename)
         elif extension == ".safetensors":
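The effect of the new `passed_components` argument, sketched with a hypothetical file list: weights that live under the folder of a component the caller already supplied in memory are skipped, so they no longer need safetensors counterparts.

    filenames = [
        "unet/diffusion_pytorch_model.bin",
        "unet/diffusion_pytorch_model.safetensors",
        "safety_checker/pytorch_model.bin",  # no .safetensors twin in this repo
    ]
    is_safetensors_compatible(filenames)  # False: safety_checker lacks safetensors weights
    is_safetensors_compatible(filenames, passed_components=["safety_checker"])  # True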
@@ -163,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
         path, filename = os.path.split(filename)
         filename, extension = os.path.splitext(filename)
 
-        if filename == "pytorch_model":
-            filename = "model"
-        elif filename == f"pytorch_model.{variant}":
-            filename = f"model.{variant}"
+        if filename.startswith("pytorch_model"):
+            filename = filename.replace("pytorch_model", "model")
         else:
             filename = filename
 
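The `startswith`/`replace` rewrite folds the exact-name and variant cases into one branch and, unlike the old equality checks, also normalizes sharded transformers names (a sketch of the mapping):

    # "pytorch_model"                -> "model"
    # "pytorch_model.fp16"           -> "model.fp16"
    # "pytorch_model-00001-of-00002" -> "model-00001-of-00002"  (new: shard names covered)
    "pytorch_model-00001-of-00002".replace("pytorch_model", "model")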
@@ -196,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
     weight_prefixes = [w.split(".")[0] for w in weight_names]
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
 
-    variant_file_regex = (
-        re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
-        if variant is not None
-        else None
-    )
-    non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
+    # -00001-of-00002
+    transformers_index_format = "\d{5}-of-\d{5}"
+
     if variant is not None:
-        variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
+        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
+        variant_file_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+        )
+        # `text_encoder/pytorch_model.bin.index.fp16.json`
+        variant_index_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+        )
+
+    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+    non_variant_file_re = re.compile(
+        f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+    )
+    # `text_encoder/pytorch_model.bin.index.json`
+    non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
+
+    if variant is not None:
+        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
+        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
+        variant_filenames = variant_weights | variant_indexes
     else:
         variant_filenames = set()
 
-    non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
+    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
+    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
+    non_variant_filenames = non_variant_weights | non_variant_indexes
 
     # all variant filenames will be used by default
     usable_filenames = set(variant_filenames)
 
+    def convert_to_variant(filename):
+        if "index" in filename:
+            variant_filename = filename.replace("index", f"index.{variant}")
+        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
+            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
+        else:
+            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
+        return variant_filename
+
     for f in non_variant_filenames:
-        variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
+        variant_filename = convert_to_variant(f)
         if variant_filename not in usable_filenames:
            usable_filenames.add(f)
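Assuming variant="fp16" and the usual diffusers/transformers weight names, the new patterns and helper behave as follows (illustrative values, not part of the commit):

    # variant_file_re      matches "diffusion_pytorch_model.fp16.bin"
    #                      and     "model.fp16-00001-of-00002.safetensors"
    # variant_index_re     matches "pytorch_model.bin.index.fp16.json"
    # non_variant_file_re  matches "diffusion_pytorch_model.bin"
    #                      and     "model-00001-of-00002.safetensors"
    # non_variant_index_re matches "pytorch_model.bin.index.json"
    convert_to_variant("pytorch_model.bin.index.json")      # -> "pytorch_model.bin.index.fp16.json"
    convert_to_variant("model-00001-of-00002.safetensors")  # -> "model.fp16-00001-of-00002.safetensors"
    convert_to_variant("diffusion_pytorch_model.bin")       # -> "diffusion_pytorch_model.fp16.bin"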
@@ -292,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
     return class_obj, class_candidates
 
 
+def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
+    if custom_pipeline is not None:
+        if custom_pipeline.endswith(".py"):
+            path = Path(custom_pipeline)
+            # decompose into folder & file
+            file_name = path.name
+            custom_pipeline = path.parent.absolute()
+        else:
+            file_name = CUSTOM_PIPELINE_FILE_NAME
+
+        return get_class_from_dynamic_module(
+            custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
+        )
+
+    if class_obj != DiffusionPipeline:
+        return class_obj
+
+    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
+    return getattr(diffusers_module, config["_class_name"])
+
+
 def load_sub_model(
     library_name: str,
     class_name: str,
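The new `_get_pipeline_class` helper centralizes class resolution that `from_pretrained` and `download` both need: an explicit `custom_pipeline` (a local `.py` file or a hub module) takes precedence, a concrete subclass passed as `class_obj` is returned as-is, and only the `DiffusionPipeline` base class falls back to the `_class_name` recorded in the config. A minimal hypothetical call:

    # resolves to the class named in the config, e.g. StableDiffusionPipeline
    _get_pipeline_class(DiffusionPipeline, {"_class_name": "StableDiffusionPipeline"})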
@@ -779,7 +830,7 @@ class DiffusionPipeline(ConfigMixin):
         device_map = kwargs.pop("device_map", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
-        kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
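The old line popped `use_safetensors` without binding the result, so the flag was silently discarded; it is now kept and forwarded to the download step, e.g.:

    pipe = DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-pipe", use_safetensors=True
    )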
@@ -794,8 +845,11 @@ class DiffusionPipeline(ConfigMixin):
                 use_auth_token=use_auth_token,
                 revision=revision,
                 from_flax=from_flax,
+                use_safetensors=use_safetensors,
                 custom_pipeline=custom_pipeline,
+                custom_revision=custom_revision,
                 variant=variant,
+                **kwargs,
             )
         else:
             cached_folder = pretrained_model_name_or_path
@@ -810,29 +864,17 @@ class DiffusionPipeline(ConfigMixin):
             for folder in os.listdir(cached_folder):
                 folder_path = os.path.join(cached_folder, folder)
                 is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
+                variant_exists = is_folder and any(
+                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+                )
                 if variant_exists:
                     model_variants[folder] = variant
 
         # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
-        if custom_pipeline is not None:
-            if custom_pipeline.endswith(".py"):
-                path = Path(custom_pipeline)
-                # decompose into folder & file
-                file_name = path.name
-                custom_pipeline = path.parent.absolute()
-            else:
-                file_name = CUSTOM_PIPELINE_FILE_NAME
-
-            pipeline_class = get_class_from_dynamic_module(
-                custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
-            )
-        elif cls != DiffusionPipeline:
-            pipeline_class = cls
-        else:
-            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
-            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+        pipeline_class = _get_pipeline_class(
+            cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+        )
 
         # DEPRECATED: To be removed in 1.0.0
         if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
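The equality check missed sharded variant files, whose second dot-separated chunk carries a shard suffix; `startswith` covers both forms (a sketch, assuming variant="fp16"):

    "diffusion_pytorch_model.fp16.bin".split(".")[1]       # "fp16" -> old and new both match
    "model.fp16-00001-of-00002.safetensors".split(".")[1]  # "fp16-00001-of-00002"
    "fp16-00001-of-00002".startswith("fp16")               # True, while == "fp16" was False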
@@ -1095,6 +1137,7 @@ class DiffusionPipeline(ConfigMixin):
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
+        custom_revision = kwargs.pop("custom_revision", None)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
@@ -1153,7 +1196,7 @@ class DiffusionPipeline(ConfigMixin):
             # this enables downloading schedulers, tokenizers, ...
             allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
             # also allow downloading config.json files with the model
-            allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
+            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
 
             allow_patterns += [
                 SCHEDULER_CONFIG_NAME,
@@ -1162,17 +1205,28 @@ class DiffusionPipeline(ConfigMixin):
                 CUSTOM_PIPELINE_FILE_NAME,
             ]
 
+            # retrieve passed components that should not be downloaded
+            pipeline_class = _get_pipeline_class(
+                cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+            )
+            expected_components, _ = cls._get_signature_keys(pipeline_class)
+            passed_components = [k for k in expected_components if k in kwargs]
+
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(model_filenames, variant=variant)
+                and not is_safetensors_compatible(
+                    model_filenames, variant=variant, passed_components=passed_components
+                )
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
+            elif use_safetensors and is_safetensors_compatible(
+                model_filenames, variant=variant, passed_components=passed_components
+            ):
                 ignore_patterns = ["*.bin", "*.msgpack"]
 
                 safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
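`download` now infers from the pipeline signature which components the caller passes in memory (e.g. `safety_checker=None`) and excludes them from the safetensors-compatibility check, so a repo that lacks safetensors weights only for a passed component no longer falls back to `.bin` for everything:

    # no safetensors weights are required for the skipped safety checker
    DiffusionPipeline.download(
        "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
    )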
@@ -1194,6 +1248,13 @@ class DiffusionPipeline(ConfigMixin):
                     f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
                 )
 
+            # Don't download any objects that are passed
+            allow_patterns = [
+                p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
+            ]
+            # Don't download index files of forbidden patterns either
+            ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
+
             re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
             re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
 
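Every weight ignore pattern now also ignores the matching index files; `*.bin`, for instance, additionally produces `*.bin.index.*json`, which after `fnmatch.translate` matches both plain and variant index names (a sketch):

    import fnmatch, re

    pat = re.compile(fnmatch.translate("*.bin.index.*json"))
    bool(pat.match("text_encoder/pytorch_model.bin.index.json"))       # True
    bool(pat.match("text_encoder/pytorch_model.bin.index.fp16.json"))  # True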
@@ -78,9 +78,7 @@ class DownloadTests(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             with requests_mock.mock(real_http=True) as m:
-                DiffusionPipeline.download(
-                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
-                )
+                DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe", cache_dir=tmpdirname)
 
             download_requests = [r.method for r in m.request_history]
             assert download_requests.count("HEAD") == 15, "15 calls to files"
@@ -101,6 +99,55 @@ class DownloadTests(unittest.TestCase):
                 len(cache_requests) == 2
             ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
 
+    def test_less_downloads_passed_object(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            cached_folder = DiffusionPipeline.download(
+                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
+            )
+
+            # make sure safety checker is not downloaded
+            assert "safety_checker" not in os.listdir(cached_folder)
+
+            # make sure rest is downloaded
+            assert "unet" in os.listdir(cached_folder)
+            assert "tokenizer" in os.listdir(cached_folder)
+            assert "vae" in os.listdir(cached_folder)
+            assert "model_index.json" in os.listdir(cached_folder)
+            assert "scheduler" in os.listdir(cached_folder)
+            assert "feature_extractor" in os.listdir(cached_folder)
+
+    def test_less_downloads_passed_object_calls(self):
+        # TODO: For some reason this test fails on MPS where no HEAD call is made.
+        if torch_device == "mps":
+            return
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with requests_mock.mock(real_http=True) as m:
+                DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
+                )
+
+            download_requests = [r.method for r in m.request_history]
+            # 15 - 2 because no call to config or model file for `safety_checker`
+            assert download_requests.count("HEAD") == 13, "13 calls to files"
+            # 17 - 2 because no call to config or model file for `safety_checker`
+            assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
+            assert (
+                len(download_requests) == 28
+            ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+
+            with requests_mock.mock(real_http=True) as m:
+                DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
+                )
+
+            cache_requests = [r.method for r in m.request_history]
+            assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
+            assert cache_requests.count("GET") == 1, "model info is only GET"
+            assert (
+                len(cache_requests) == 2
+            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+
     def test_download_only_pytorch(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             # pipeline has Flax weights
@@ -165,6 +212,54 @@ class DownloadTests(unittest.TestCase):
             # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
             assert not any(f.endswith(".bin") for f in files)
 
+    def test_download_safetensors_index(self):
+        for variant in ["fp16", None]:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tmpdirname = DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
+                    cache_dir=tmpdirname,
+                    use_safetensors=True,
+                    variant=variant,
+                )
+
+                all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+                files = [item for sublist in all_root_files for item in sublist]
+
+                # None of the downloaded files should be a pytorch `.bin` file even if we have some here:
+                # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder
+                if variant is None:
+                    assert not any("fp16" in f for f in files)
+                else:
+                    model_files = [f for f in files if "safetensors" in f]
+                    assert all("fp16" in f for f in model_files)
+
+                assert len([f for f in files if ".safetensors" in f]) == 8
+                assert not any(".bin" in f for f in files)
+
+    def test_download_bin_index(self):
+        for variant in ["fp16", None]:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tmpdirname = DiffusionPipeline.download(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
+                    cache_dir=tmpdirname,
+                    use_safetensors=False,
+                    variant=variant,
+                )
+
+                all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+                files = [item for sublist in all_root_files for item in sublist]
+
+                # None of the downloaded files should be a safetensors file even if we have some here:
+                # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder
+                if variant is None:
+                    assert not any("fp16" in f for f in files)
+                else:
+                    model_files = [f for f in files if "bin" in f]
+                    assert all("fp16" in f for f in model_files)
+
+                assert len([f for f in files if ".bin" in f]) == 8
+                assert not any(".safetensors" in f for f in files)
+
     def test_download_no_safety_checker(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained(
@@ -362,6 +457,33 @@ class DownloadTests(unittest.TestCase):
 
         diffusers.utils.import_utils._safetensors_available = True
 
+    def test_local_save_load_index(self):
+        prompt = "hello"
+        for variant in [None, "fp16"]:
+            for use_safe in [True, False]:
+                pipe = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
+                    variant=variant,
+                    use_safetensors=use_safe,
+                    safety_checker=None,
+                )
+                pipe = pipe.to(torch_device)
+                generator = torch.manual_seed(0)
+                out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    pipe.save_pretrained(tmpdirname)
+                    pipe_2 = StableDiffusionPipeline.from_pretrained(
+                        tmpdirname, safe_serialization=use_safe, variant=variant
+                    )
+                    pipe_2 = pipe_2.to(torch_device)
+
+                    generator = torch.manual_seed(0)
+
+                    out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+                assert np.max(np.abs(out - out_2)) < 1e-3
+
     def test_text_inversion_download(self):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None