diff --git a/setup.py b/setup.py index dbc0f86dc1..81dc5a3973 100644 --- a/setup.py +++ b/setup.py @@ -251,7 +251,7 @@ else: extras["flax"] = deps_list("jax", "jaxlib", "flax") -extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] +extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] install_requires = [ deps["importlib_metadata"],