From dec18c86321de39ad853e187826a70c435467c16 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Fri, 21 Oct 2022 13:13:36 +0200
Subject: [PATCH] [Flax] dont warn for bf16 weights (#923)

dont warn for bf16 weights
---
 src/diffusers/modeling_flax_utils.py | 23 -----------------------
 1 file changed, 23 deletions(-)

diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index 6cb30a26f7..e31a9b7b80 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -482,29 +482,6 @@ class FlaxModelMixin:
                 " training."
             )
 
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
         return model, unflatten_dict(state)
 
     def save_pretrained(
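
Note (not part of the patch): with the warning removed, loading fp16/bf16 Flax weights is treated as intentional. Below is a minimal sketch of how a caller could still inspect parameter dtypes on their own, using only standard JAX/Flax utilities; the helper name `report_non_fp32` and the toy `params` pytree are illustrative assumptions, not code from diffusers.

    # Sketch: check which parameter leaves are not float32, mirroring the
    # dtype bookkeeping that the removed warning used to do internally.
    import jax.numpy as jnp
    from flax.traverse_util import flatten_dict

    def report_non_fp32(params):
        """Return flattened keys of all parameter leaves that are not float32."""
        leaf_dtypes = {k: v.dtype for k, v in flatten_dict(params).items()}
        return [k for k, dt in leaf_dtypes.items() if dt != jnp.float32]

    # Toy pytree standing in for the params returned by
    # FlaxModelMixin.from_pretrained (structure and shapes are made up).
    params = {
        "conv_in": {"kernel": jnp.zeros((3, 3, 4, 8), dtype=jnp.bfloat16)},
        "time_embedding": {"dense": {"kernel": jnp.zeros((8, 32), dtype=jnp.float32)}},
    }

    print(report_non_fp32(params))
    # -> [('conv_in', 'kernel')]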