diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 7c9e4e46a5..19f58fd816 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -456,6 +456,9 @@ def flax_register_to_config(cls): # Make sure init_kwargs override default kwargs new_kwargs = {**default_kwargs, **init_kwargs} + # dtype should be part of `init_kwargs`, but not `new_kwargs` + if "dtype" in new_kwargs: + new_kwargs.pop("dtype") # Get positional arguments aligned with kwargs for i, arg in enumerate(args):