From b41704229119eecbd5c4f2789710c784fd8e2335 Mon Sep 17 00:00:00 2001 From: apolinario Date: Tue, 13 Dec 2022 12:44:20 +0100 Subject: [PATCH] Fix wrong type checking in `convert_diffusers_to_original_stable_diffusion.py` (#1681) * Fix type checking remainders * Remove IS_V20_MODEL flag always being True Co-authored-by: apolinario --- scripts/convert_diffusers_to_original_stable_diffusion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index e9954a5588..11b4b873e7 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -209,7 +209,7 @@ textenc_pattern = re.compile("|".join(protected.keys())) code2idx = {"q": 0, "k": 1, "v": 2} -def convert_text_enc_state_dict_v20(text_enc_dict: dict[str, torch.Tensor]): +def convert_text_enc_state_dict_v20(text_enc_dict): new_state_dict = {} capture_qkv_weight = {} capture_qkv_bias = {} @@ -256,12 +256,10 @@ def convert_text_enc_state_dict_v20(text_enc_dict: dict[str, torch.Tensor]): return new_state_dict -def convert_text_enc_state_dict(text_enc_dict: dict[str, torch.Tensor]): +def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -IS_V20_MODEL = True - if __name__ == "__main__": parser = argparse.ArgumentParser()