Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
@@ -482,29 +482,6 @@ class FlaxModelMixin:
                 " training."
             )
 
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
         return model, unflatten_dict(state)
 
     def save_pretrained(
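The removed block is essentially a dtype audit over the loaded parameter pytree: collect each leaf's dtype and warn if any leaves are fp16 or bf16. Below is a minimal standalone sketch of the same check that callers can run themselves after loading, assuming a Flax parameter dict such as the `params` returned by `from_pretrained`; the helper name `non_float32_param_keys` is ours, not part of diffusers.

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict


def non_float32_param_keys(params):
    """Return the flattened keys of fp16 and bf16 leaves in a Flax param pytree (illustrative helper, not a diffusers API)."""
    # flatten_dict turns the nested param dict into {("block", "kernel"): array, ...}
    flat = flatten_dict(params)
    # map every leaf array to its dtype, keeping the same (flat) structure
    dtypes = jax.tree_util.tree_map(lambda x: x.dtype, flat)
    fp16 = [k for k, d in dtypes.items() if d == jnp.float16]
    bf16 = [k for k, d in dtypes.items() if d == jnp.bfloat16]
    return fp16, bf16

After `model, params = SomeFlaxModel.from_pretrained(...)` (placeholder class name), calling `non_float32_param_keys(params)` flags any half-precision leaves; they can then be upcast with the mixin's `to_fp32(params)`, as the removed warning suggested.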