Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Add init_weights method to FlaxMixin (#513)
* Add `init_weights` method to `FlaxMixin`
* Rename `random_state` -> `shape_state`
* Use `PRNGKey(0)` for `jax.eval_shape`
* Do not allow mismatched sizes
* Update src/diffusers/modeling_flax_utils.py
* Docstring fixes

Co-authored-by: Suraj Patil <surajp815@gmail.com>
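Taken together, the changes let the loader compare a checkpoint's parameter tree against the tree the model expects, using flattened key sets. As a quick orientation before the diff, here is a minimal sketch of that comparison outside any diffusers class (the toy trees are invented for illustration, not taken from this commit):

    from flax.traverse_util import flatten_dict

    # flatten_dict turns nested param dicts into {tuple-path: leaf} mappings,
    # so ordinary set arithmetic surfaces the discrepancies the loader reports.
    state = flatten_dict({"encoder": {"kernel": 1}, "legacy_head": {"bias": 2}})
    expected = flatten_dict({"encoder": {"kernel": 0}, "decoder": {"kernel": 0}})

    missing_keys = set(expected) - set(state)      # {('decoder', 'kernel')}
    unexpected_keys = set(state) - set(expected)   # {('legacy_head', 'bias')}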
@@ -20,7 +20,7 @@ from typing import Any, Dict, Union
 import jax
 import jax.numpy as jnp
 import msgpack.exceptions
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import hf_hub_download
@@ -183,6 +183,9 @@ class FlaxModelMixin:
         ```"""
         return self._cast_floating_to(params, jnp.float16, mask)
 
+    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
+        raise NotImplementedError(f"init_weights method has to be implemented for {self}")
+
     @classmethod
     def from_pretrained(
         cls,
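A subclass is expected to build its parameter pytree from the PRNG key. A minimal sketch of what an implementation might look like, using a hypothetical Flax module (the module, feature count, and input shape are invented for illustration, not part of this diff):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class ToyModel(nn.Module):
        features: int = 4

        @nn.compact
        def __call__(self, x):
            return nn.Dense(self.features)(x)

        def init_weights(self, rng: jax.random.PRNGKey):
            # Build the parameter pytree from a dummy input; a real model
            # would use its own expected input shapes here.
            sample = jnp.zeros((1, 8), dtype=jnp.float32)
            return self.init(rng, sample)["params"]

    params = ToyModel().init_weights(jax.random.PRNGKey(0))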
@@ -227,10 +230,6 @@ class FlaxModelMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
-            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
-                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
-                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
-                checkpoint with 3 labels).
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -394,6 +393,72 @@ class FlaxModelMixin:
         # flatten dicts
         state = flatten_dict(state)
 
+        params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
+        required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        shape_state = flatten_dict(unfreeze(params_shape_tree))
+
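`jax.eval_shape` traces `init_weights` abstractly: no parameters are allocated and no FLOPs are spent, only shapes and dtypes propagate. Because nothing is actually computed, the constant `PRNGKey(0)` is safe; any key would produce the same shape tree. A small self-contained demonstration (the toy init function is invented):

    import jax
    import jax.numpy as jnp

    def fake_init(rng):
        # Stand-in for model.init_weights: returns a parameter pytree.
        return {"dense": {"kernel": jnp.zeros((8, 4)), "bias": jnp.zeros((4,))}}

    # Only abstract shape/dtype information flows through eval_shape.
    shape_tree = jax.eval_shape(fake_init, rng=jax.random.PRNGKey(0))
    print(shape_tree["dense"]["kernel"])  # ShapeDtypeStruct(shape=(8, 4), dtype=float32)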
+        missing_keys = required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - required_params
+
+        if missing_keys:
+            logger.warning(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                "Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
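Stashing the set on the class as `cls._missing_keys` keeps it available after loading; per the warning text, the caller is then expected to run `model.init_weights(...)` to produce values for the uninitialized subtrees.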
+        # Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in shape_state and state[key].shape != shape_state[key].shape:
+                raise ValueError(
+                    f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                    f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}."
+                )
+
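Note a consequence of the commit's "do not allow mismatched sizes" decision: any shape mismatch raises immediately inside this loop, so `mismatched_keys` can never be populated, and the shape-mismatch warning branch further down is effectively unreachable, seemingly retained from the transformers loader this logic mirrors.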
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
         # dictionary of key: dtypes for the model params
         param_dtypes = jax.tree_map(lambda x: x.dtype, state)
         # extract keys of parameters not in jnp.float32
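For illustration only, not the committed code: a sketch of how the non-float32 keys named in the final comment could be extracted with the same `jax.tree_map` pattern (the toy `state` dict is invented):

    import jax
    import jax.numpy as jnp

    # Toy flattened state standing in for the loader's `state` dict.
    state = {
        ("dense", "kernel"): jnp.zeros((4, 4), dtype=jnp.float16),
        ("dense", "bias"): jnp.zeros((4,), dtype=jnp.float32),
    }
    param_dtypes = jax.tree_map(lambda x: x.dtype, state)
    non_fp32_keys = [k for k, dtype in param_dtypes.items() if dtype != jnp.float32]
    print(non_fp32_keys)  # [('dense', 'kernel')]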