1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Loading] Ignore unneeded files (#1107)

* [Loading] Ignore unneeded files

* up
This commit is contained in:
Patrick von Platen
2022-11-02 19:20:42 +01:00
committed by GitHub
parent cbcd0512f0
commit c39a511b5f
4 changed files with 51 additions and 2 deletions

View File

@@ -302,10 +302,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
requested_pipeline_class = (
requested_pipeline_class
if requested_pipeline_class.startswith("Flax")
else "Flax" + requested_pipeline_class
)
user_agent = {"pipeline_class": requested_pipeline_class}
user_agent = http_user_agent(user_agent)
@@ -319,6 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:
@@ -337,7 +347,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
pipeline_class = getattr(diffusers_module, class_name)
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`

View File

@@ -395,6 +395,9 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# make sure we don't download flax weights
ignore_patterns = "*.msgpack"
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
@@ -417,6 +420,7 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:

View File

@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
assert captured.err == "", "Progress bar should be disabled"
class DownloadTests(unittest.TestCase):
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a flax file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert not any(f.endswith(".msgpack") for f in files)
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
pipeline = DiffusionPipeline.from_pretrained(

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
@@ -24,12 +26,29 @@ from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
@require_flax
class DownloadTests(unittest.TestCase):
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = FlaxDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a PyTorch file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
assert not any(f.endswith(".bin") for f in files)
@slow
@require_flax
class FlaxPipelineTests(unittest.TestCase):