# Python dependencies, one requirement per line as required by pip's
# requirements-file format (install with `pip install -r <this file>`).

# Released packages, each constrained only by a minimum version.
datasets>=1.1.3
jax>=0.2.8
jaxlib>=0.1.59

# Installed straight from the upstream Git repositories. No branch, tag or
# commit is specified, so pip builds whatever the repository's default branch
# contains at install time (i.e. these two are not reproducibly pinned).
git+https://github.com/google/flax.git
git+https://github.com/deepmind/optax.git