From 4ced8799303b69e5dbe791194289dc08f73017b2 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 5 Aug 2025 22:48:25 +0530
Subject: [PATCH] update

---
 src/diffusers/loaders/single_file_model.py | 26 ++++++++++++++--------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 76fefc1260..e1c87476ae 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -156,6 +156,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
 }
 
 
+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
 def _get_single_file_loadable_mapping_class(cls):
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +385,23 @@ class FromOriginalModelMixin:
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)
 
-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        diffusers_format_checkpoint = checkpoint_mapping_fn(
-            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-        )
-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )
-
         ctx = init_empty_weights if is_accelerate_available() else nullcontext
         with ctx():
             model = cls.from_config(diffusers_model_config)
 
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
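
Note: the new helper decides whether a single-file checkpoint still needs conversion by testing whether the model's expected parameter keys are already a subset of the checkpoint's keys. A minimal self-contained sketch of that subset test follows; the dict literals and key names are hypothetical stand-ins for real state dicts, not values from the patch.

```python
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    # Convert only when the checkpoint lacks keys the model expects,
    # i.e. it is still in the original (non-diffusers) layout.
    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))


model_sd = {"conv_in.weight": 0, "conv_in.bias": 0}

# Checkpoint already in diffusers layout: every model key present -> skip conversion.
assert not _should_convert_state_dict_to_diffusers(
    model_sd, {"conv_in.weight": 0, "conv_in.bias": 0, "extra": 0}
)

# Original-format checkpoint: model keys absent -> run checkpoint_mapping_fn.
assert _should_convert_state_dict_to_diffusers(
    model_sd, {"model.diffusion_model.input_blocks.0.0.weight": 0}
)
```

Because the check compares against `model.state_dict()`, the model must be instantiated first, which is why the patch moves the `cls.from_config` call ahead of the mapping step.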