mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[DiffusionPipeline] Deprecate not throwing error when loading non-existant variant (#4011)
* Deprecate variant nicely * make style * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -1213,6 +1213,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
deprecation_message = (
|
||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
|
||||
"if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant"
|
||||
"modeling files is deprecated."
|
||||
)
|
||||
deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@@ -56,6 +57,7 @@ from diffusers import (
|
||||
UniPCMultistepScheduler,
|
||||
logging,
|
||||
)
|
||||
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
@@ -1361,6 +1363,29 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert sd.config.safety_checker != (None, None)
|
||||
assert sd.config.feature_extractor != (None, None)
|
||||
|
||||
def test_warning_no_variant_available(self):
|
||||
variant = "fp16"
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
cached_folder = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
|
||||
)
|
||||
|
||||
assert "but no such modeling files are available" in str(warning_context.warning)
|
||||
assert variant in str(warning_context.warning)
|
||||
|
||||
def get_all_filenames(directory):
|
||||
filenames = glob.glob(directory + "/**", recursive=True)
|
||||
filenames = [f for f in filenames if os.path.isfile(f)]
|
||||
return filenames
|
||||
|
||||
filenames = get_all_filenames(str(cached_folder))
|
||||
|
||||
all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
# make sure that none of the model names are variant model names
|
||||
assert len(variant_model_files) == 0
|
||||
assert len(all_model_files) > 0
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user