1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[DiffusionPipeline] Deprecate not throwing error when loading non-existant variant (#4011)

* Deprecate variant nicely

* make style

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen
2023-07-10 12:12:29 +02:00
parent 4dfcfaa137
commit 638d2bbcd9
2 changed files with 34 additions and 0 deletions

View File

@@ -1213,6 +1213,15 @@ class DiffusionPipeline(ConfigMixin):
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
if len(variant_filenames) == 0 and variant is not None:
deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)
# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import gc
import glob
import json
import os
import random
@@ -56,6 +57,7 @@ from diffusers import (
UniPCMultistepScheduler,
logging,
)
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
@@ -1361,6 +1363,29 @@ class PipelineFastTests(unittest.TestCase):
assert sd.config.safety_checker != (None, None)
assert sd.config.feature_extractor != (None, None)
def test_warning_no_variant_available(self):
variant = "fp16"
with self.assertWarns(FutureWarning) as warning_context:
cached_folder = StableDiffusionPipeline.download(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
)
assert "but no such modeling files are available" in str(warning_context.warning)
assert variant in str(warning_context.warning)
def get_all_filenames(directory):
filenames = glob.glob(directory + "/**", recursive=True)
filenames = [f for f in filenames if os.path.isfile(f)]
return filenames
filenames = get_all_filenames(str(cached_folder))
all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant)
# make sure that none of the model names are variant model names
assert len(variant_model_files) == 0
assert len(all_model_files) > 0
@slow
@require_torch_gpu