diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index e9954a5588..11b4b873e7 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -209,7 +209,7 @@ textenc_pattern = re.compile("|".join(protected.keys())) code2idx = {"q": 0, "k": 1, "v": 2} -def convert_text_enc_state_dict_v20(text_enc_dict: dict[str, torch.Tensor]): +def convert_text_enc_state_dict_v20(text_enc_dict): new_state_dict = {} capture_qkv_weight = {} capture_qkv_bias = {} @@ -256,12 +256,10 @@ def convert_text_enc_state_dict_v20(text_enc_dict: dict[str, torch.Tensor]): return new_state_dict -def convert_text_enc_state_dict(text_enc_dict: dict[str, torch.Tensor]): +def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -IS_V20_MODEL = True - if __name__ == "__main__": parser = argparse.ArgumentParser()