From 638d2bbcd9bd1e99f91d35b6b497b47b270f1ba8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 12:12:29 +0200 Subject: [PATCH] [DiffusionPipeline] Deprecate not throwing error when loading non-existant variant (#4011) * Deprecate variant nicely * make style * Apply suggestions from code review Co-authored-by: Sayak Paul * Apply suggestions from code review Co-authored-by: Pedro Cuenca --------- Co-authored-by: Sayak Paul Co-authored-by: Pedro Cuenca --- src/diffusers/pipelines/pipeline_utils.py | 9 ++++++++ tests/pipelines/test_pipelines.py | 25 +++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 87c0f711a3..29c859d087 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1213,6 +1213,15 @@ class DiffusionPipeline(ConfigMixin): filenames = {sibling.rfilename for sibling in info.siblings} model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + if len(variant_filenames) == 0 and variant is not None: + deprecation_message = ( + f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" + "if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant" + "modeling files is deprecated." + ) + deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False) + # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) variant_filenames = set(variant_filenames) - set(ignore_filenames) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 7ee2c632e6..c4241f1245 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -14,6 +14,7 @@ # limitations under the License. import gc +import glob import json import os import random @@ -56,6 +57,7 @@ from diffusers import ( UniPCMultistepScheduler, logging, ) +from diffusers.pipelines.pipeline_utils import variant_compatible_siblings from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import ( CONFIG_NAME, @@ -1361,6 +1363,29 @@ class PipelineFastTests(unittest.TestCase): assert sd.config.safety_checker != (None, None) assert sd.config.feature_extractor != (None, None) + def test_warning_no_variant_available(self): + variant = "fp16" + with self.assertWarns(FutureWarning) as warning_context: + cached_folder = StableDiffusionPipeline.download( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant + ) + + assert "but no such modeling files are available" in str(warning_context.warning) + assert variant in str(warning_context.warning) + + def get_all_filenames(directory): + filenames = glob.glob(directory + "/**", recursive=True) + filenames = [f for f in filenames if os.path.isfile(f)] + return filenames + + filenames = get_all_filenames(str(cached_folder)) + + all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant) + + # make sure that none of the model names are variant model names + assert len(variant_model_files) == 0 + assert len(all_model_files) > 0 + @slow @require_torch_gpu