mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve loading pipe (#4009)
* improve loading subcomponents * Add test for logging * improve loading subcomponents * make style * make style * fix * finish
This commit is contained in:
committed by
GitHub
parent
7a91ea6c2b
commit
080ecf01b3
@@ -556,6 +556,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
model_index_dict.pop("_class_name", None)
|
||||
model_index_dict.pop("_diffusers_version", None)
|
||||
model_index_dict.pop("_module", None)
|
||||
model_index_dict.pop("_name_or_path", None)
|
||||
|
||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
||||
|
||||
@@ -1013,7 +1014,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
from diffusers import pipelines
|
||||
|
||||
# 6. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
||||
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
@@ -1055,6 +1056,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
@@ -1073,8 +1077,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 8. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
|
||||
# 9. Save where the model was instantiated from
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
return model
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return getattr(self.config, "_name_or_path", None)
|
||||
|
||||
@classmethod
|
||||
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
||||
r"""
|
||||
|
||||
@@ -821,7 +821,12 @@ class CustomPipelineTests(unittest.TestCase):
|
||||
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
|
||||
pipe_new.save_pretrained(tmpdirname)
|
||||
|
||||
assert dict(pipe_new.config) == dict(pipe.config)
|
||||
conf_1 = dict(pipe.config)
|
||||
conf_2 = dict(pipe_new.config)
|
||||
|
||||
del conf_2["_name_or_path"]
|
||||
|
||||
assert conf_1 == conf_2
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -1363,6 +1368,18 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert sd.config.safety_checker != (None, None)
|
||||
assert sd.config.feature_extractor != (None, None)
|
||||
|
||||
def test_name_or_path(self):
|
||||
model_path = "hf-internal-testing/tiny-stable-diffusion-torch"
|
||||
sd = DiffusionPipeline.from_pretrained(model_path)
|
||||
|
||||
assert sd.name_or_path == model_path
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd.save_pretrained(tmpdirname)
|
||||
sd = DiffusionPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
assert sd.name_or_path == tmpdirname
|
||||
|
||||
def test_warning_no_variant_available(self):
|
||||
variant = "fp16"
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
|
||||
@@ -17,7 +17,7 @@ from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
|
||||
from diffusers.utils.testing_utils import require_torch, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
@@ -298,9 +298,19 @@ class PipelineTesterMixin:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
|
||||
logger.setLevel(diffusers.logging.INFO)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
|
||||
for name in pipe_loaded.components.keys():
|
||||
if name not in pipe_loaded._optional_components:
|
||||
assert name in str(cap_logger)
|
||||
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user