From bc0392a0cbac301474ef82eed5818d2030a4fc4c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 21 Apr 2023 08:01:36 -1000 Subject: [PATCH] make `from_flax` work for controlnet (#3161) fix from_flax Co-authored-by: Patrick von Platen --- src/diffusers/models/modeling_pytorch_flax_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/models/modeling_pytorch_flax_utils.py b/src/diffusers/models/modeling_pytorch_flax_utils.py index b368a74ca2..17b521b001 100644 --- a/src/diffusers/models/modeling_pytorch_flax_utils.py +++ b/src/diffusers/models/modeling_pytorch_flax_utils.py @@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): .replace("_1", ".1") .replace("_2", ".2") .replace("_3", ".3") + .replace("_4", ".4") + .replace("_5", ".5") + .replace("_6", ".6") + .replace("_7", ".7") + .replace("_8", ".8") + .replace("_9", ".9") ) flax_key = ".".join(flax_key_tuple_array)