1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[Flax] dont warn for bf16 weights (#923)

dont warn for bf16 weights
This commit is contained in:
Suraj Patil
2022-10-21 13:13:36 +02:00
committed by GitHub
parent 25dfd0f8dc
commit dec18c8632

View File

@@ -482,29 +482,6 @@ class FlaxModelMixin:
" training."
)
# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
# raise a warning if any of the parameters are not in jnp.float32
if len(fp16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~ModelMixin.to_fp32`] for further information on how to do this."
)
if len(bf16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~ModelMixin.to_fp32`] for further information on how to do this."
)
return model, unflatten_dict(state)
def save_pretrained(