mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Loading] Ignore unneeded files (#1107)
* [Loading] Ignore unneeded files * up
This commit is contained in:
committed by
GitHub
parent
cbcd0512f0
commit
c39a511b5f
@@ -302,10 +302,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download PyTorch weights
|
||||
ignore_patterns = "*.bin"
|
||||
|
||||
if cls != FlaxDiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
requested_pipeline_class = (
|
||||
requested_pipeline_class
|
||||
if requested_pipeline_class.startswith("Flax")
|
||||
else "Flax" + requested_pipeline_class
|
||||
)
|
||||
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
@@ -319,6 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
@@ -337,7 +347,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
if config_dict["_class_name"].startswith("Flax")
|
||||
else "Flax" + config_dict["_class_name"]
|
||||
)
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
pipeline_class = getattr(diffusers_module, class_name)
|
||||
|
||||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
|
||||
@@ -395,6 +395,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download flax weights
|
||||
ignore_patterns = "*.msgpack"
|
||||
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
|
||||
@@ -417,6 +420,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
|
||||
assert captured.err == "", "Progress bar should be disabled"
|
||||
|
||||
|
||||
class DownloadTests(unittest.TestCase):
|
||||
def test_download_only_pytorch(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a flax file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
|
||||
assert not any(f.endswith(".msgpack") for f in files)
|
||||
|
||||
|
||||
class CustomPipelineTests(unittest.TestCase):
|
||||
def test_load_custom_pipeline(self):
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -24,12 +26,29 @@ from diffusers.utils.testing_utils import require_flax, slow
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline
|
||||
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
from jax import pmap
|
||||
|
||||
|
||||
@require_flax
|
||||
class DownloadTests(unittest.TestCase):
|
||||
def test_download_only_pytorch(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
_ = FlaxDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a PyTorch file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
|
||||
assert not any(f.endswith(".bin") for f in files)
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxPipelineTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user