From ea6938aea589b034c2320964bf066ba6dd33b12c Mon Sep 17 00:00:00 2001 From: "Donald.Lee" <47619881+February24-Lee@users.noreply.github.com> Date: Thu, 27 Jun 2024 02:00:49 +0900 Subject: [PATCH] Fix: unet save_attn_procs at UNet2DconditionLoadersMixin (#8699) * fix: unet save_attn_procs at custom diffusion * style: recover unchanaged parts(max line length 119) / mod: add condition * style: recover unchanaged parts(max line length 119) --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/unet.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 22a064c9f4..58c9c0e60d 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -457,6 +457,15 @@ class UNet2DConditionLoadersMixin: ) if is_custom_diffusion: state_dict = self._get_custom_diffusion_state_dict() + if save_function is None and safe_serialization: + # safetensors does not support saving dicts with non-tensor values + empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)} + if len(empty_state_dict) > 0: + logger.warning( + f"Safetensors does not support saving dicts with non-tensor values. " + f"The following keys will be ignored: {empty_state_dict.keys()}" + ) + state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} else: if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")