diff --git a/examples/flax/README.md b/examples/flax/README.md
index 978d34e2f7d..039bf9de18c 100644
--- a/examples/flax/README.md
+++ b/examples/flax/README.md
@@ -28,13 +28,13 @@ efficient vectorization), and `pjit` (for automatically sharded model parallelis
 computing per-example gradients is simply `vmap(grad(f))`.
 
 [Flax](https://github.com/google/flax) builds on top of JAX with an ergonomic
-module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformation and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples. (TODO: Add link once it's there.)
+module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformation and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples.
 
 ## Running on Cloud TPU
 
 All of our JAX/Flax models are designed to run efficiently on Google
-Cloud TPUs. Here is a guide for running jobs on Google Cloud TPU.
-(TODO: Add a link to the Cloud TPU JAX getting started guide once it's public)
+Cloud TPUs. Here is [a guide for running JAX on Google Cloud TPU](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
+
 
 Each example README contains more details on the specific model and training procedure.
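
For context on the `vmap(grad(f))` pattern the README text above mentions: `grad` gives the gradient for a single example, and `vmap` maps that function over a batch, producing one gradient per example rather than a single averaged gradient. A minimal, illustrative sketch (the linear model, loss function, and toy data below are made up for this note, not taken from the examples):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Squared error for a toy linear model: y ≈ x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples
ys = jnp.arange(4.0)

# grad(loss) differentiates w.r.t. params for one example;
# vmap maps it over the batch axis of xs and ys, keeping params unmapped.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
print(jax.tree_util.tree_map(lambda g: g.shape, per_example_grads))
# each leaf now has a leading batch dimension: w -> (4, 3), b -> (4,)
```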
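
The "lifted" transformations the README refers to apply `vmap`/`remat` directly to Flax modules rather than to plain functions. A hedged sketch, assuming current `flax.linen` APIs (`nn.vmap`, `nn.remat`) and a made-up `MLP` module that is not one of the example models:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        return nn.Dense(self.features)(x)

# Lifted transforms compose on Module classes: checkpoint (remat) the MLP's
# activations, then vectorize it over a leading batch axis with shared parameters.
RematMLP = nn.remat(MLP)
BatchedMLP = nn.vmap(
    RematMLP,
    in_axes=0, out_axes=0,
    variable_axes={"params": None},  # parameters are broadcast, not mapped
    split_rngs={"params": False},    # one init RNG shared across the batch
)

model = BatchedMLP(features=8)
x = jnp.ones((4, 16))  # batch of 4 inputs
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
print(y.shape)  # (4, 8)
```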