Link official Cloud TPU JAX docs (#11892)
commit 2df546918e
parent 1530384e5b
```diff
@@ -28,13 +28,13 @@ efficient vectorization), and `pjit` (for automatically sharded model parallelism)
 computing per-example gradients is simply `vmap(grad(f))`.
 
 [Flax](https://github.com/google/flax) builds on top of JAX with an ergonomic
-module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformations and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples. (TODO: Add link once it's there.)
+module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformations and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples.
 
 ## Running on Cloud TPU
 
 All of our JAX/Flax models are designed to run efficiently on Google
-Cloud TPUs. Here is a guide for running jobs on Google Cloud TPU.
-(TODO: Add a link to the Cloud TPU JAX getting started guide once it's public)
+Cloud TPUs. Here is [a guide for running JAX on Google Cloud TPU](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
 
 Each example README contains more details on the specific model and training
 procedure.
```
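The `vmap(grad(f))` line in the context above is the standard JAX recipe for per-example gradients. A minimal sketch of the pattern (the linear-model loss and the toy batch here are invented for illustration, not taken from the repository):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example.
    return (jnp.dot(x, w) - y) ** 2

w = jnp.ones(3)                      # shared parameters
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples
ys = jnp.arange(4.0)                 # batch of 4 targets

# grad(loss) differentiates w.r.t. w for one example; vmap maps that
# over the leading batch axis of xs and ys, so instead of one summed
# gradient we get one gradient per example.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
print(per_example_grads.shape)  # (4, 3)
```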
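The "module abstraction using Python dataclasses" that both sides of the first hunk describe looks roughly like the following in `flax.linen` (a minimal sketch; the `MLP` module and its sizes are made up for the example):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    # Dataclass-style fields: nn.Module generates __init__ from these.
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = MLP(hidden=32, out=2)
# init traces __call__ once on dummy input to infer parameter shapes.
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
y = model.apply(params, jnp.ones((4, 8)))
print(y.shape)  # (4, 2)
```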
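Once a TPU VM is provisioned per the guide the new line links to, a quick sanity check that JAX actually sees the TPU cores (output varies by TPU topology; this is an assumption-free use of standard JAX calls, not part of the linked guide):

```python
import jax

# On a Cloud TPU VM this should report "tpu" and list TpuDevice entries;
# on a CPU-only machine JAX falls back to "cpu".
print(jax.default_backend())
print(jax.devices())
```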