1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve loading pipe (#4009)

* improve loading subcomponents

* Add test for logging

* improve loading subcomponents

* make style

* make style

* fix

* finish
This commit is contained in:
Patrick von Platen
2023-07-10 18:05:40 +02:00
committed by GitHub
parent 7a91ea6c2b
commit 080ecf01b3
3 changed files with 42 additions and 4 deletions

View File

@@ -556,6 +556,7 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_class_name", None)
model_index_dict.pop("_diffusers_version", None)
model_index_dict.pop("_module", None)
model_index_dict.pop("_name_or_path", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -1013,7 +1014,7 @@ class DiffusionPipeline(ConfigMixin):
from diffusers import pipelines
# 6. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
@@ -1055,6 +1056,9 @@ class DiffusionPipeline(ConfigMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
@@ -1073,8 +1077,15 @@ class DiffusionPipeline(ConfigMixin):
# 8. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
# 9. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
return model
@property
def name_or_path(self) -> Union[str, None]:
    """Name or path of the checkpoint this pipeline was instantiated from.

    Returns:
        The `_name_or_path` value that `from_pretrained` registered into the
        pipeline config, or `None` if the pipeline was constructed directly
        and the field was never recorded.
    """
    # `getattr` with a `None` default: pipelines built without
    # `from_pretrained` have no `_name_or_path` in their config, so the
    # return annotation must admit `None` (the original `-> str` was wrong).
    return getattr(self.config, "_name_or_path", None)
@classmethod
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
r"""

View File

@@ -821,7 +821,12 @@ class CustomPipelineTests(unittest.TestCase):
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
pipe_new.save_pretrained(tmpdirname)
assert dict(pipe_new.config) == dict(pipe.config)
conf_1 = dict(pipe.config)
conf_2 = dict(pipe_new.config)
del conf_2["_name_or_path"]
assert conf_1 == conf_2
@slow
@require_torch_gpu
@@ -1363,6 +1368,18 @@ class PipelineFastTests(unittest.TestCase):
assert sd.config.safety_checker != (None, None)
assert sd.config.feature_extractor != (None, None)
def test_name_or_path(self):
    """Both hub loading and local round-tripping must record the source path."""
    repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

    pipe = DiffusionPipeline.from_pretrained(repo_id)
    assert pipe.name_or_path == repo_id

    with tempfile.TemporaryDirectory() as local_dir:
        pipe.save_pretrained(local_dir)
        reloaded = DiffusionPipeline.from_pretrained(local_dir)
        assert reloaded.name_or_path == local_dir
def test_warning_no_variant_available(self):
variant = "fp16"
with self.assertWarns(FutureWarning) as warning_context:

View File

@@ -17,7 +17,7 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
from diffusers.utils.testing_utils import require_torch, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
def to_np(tensor):
@@ -298,9 +298,19 @@ class PipelineTesterMixin:
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:
assert name in str(cap_logger)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)