Link official Cloud TPU JAX docs (#11892)

This commit is contained in:
Avital Oliver 2021-05-26 21:44:40 +02:00 committed by GitHub
parent 1530384e5b
commit 2df546918e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,13 +28,13 @@ efficient vectorization), and `pjit` (for automatically sharded model parallelis
computing per-example gradients is simply `vmap(grad(f))`.
[Flax](https://github.com/google/flax) builds on top of JAX with an ergonomic
module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformation and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples. (TODO: Add link once it's there.)
module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformation and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples.
## Running on Cloud TPU
All of our JAX/Flax models are designed to run efficiently on Google
Cloud TPUs. Here is a guide for running jobs on Google Cloud TPU.
(TODO: Add a link to the Cloud TPU JAX getting started guide once it's public)
Cloud TPUs. Here is [a guide for running JAX on Google Cloud TPU](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
Each example README contains more details on the specific model and training
procedure.