diff --git a/src/diffusers/models/modeling_pytorch_flax_utils.py b/src/diffusers/models/modeling_pytorch_flax_utils.py index b368a74ca2..17b521b001 100644 --- a/src/diffusers/models/modeling_pytorch_flax_utils.py +++ b/src/diffusers/models/modeling_pytorch_flax_utils.py @@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): .replace("_1", ".1") .replace("_2", ".2") .replace("_3", ".3") + .replace("_4", ".4") + .replace("_5", ".5") + .replace("_6", ".6") + .replace("_7", ".7") + .replace("_8", ".8") + .replace("_9", ".9") ) flax_key = ".".join(flax_key_tuple_array)