Update README.md

This commit is contained in:
Sanchit Gandhi 2022-03-04 09:58:45 +01:00 committed by GitHub
parent a6e3b17981
commit b71474895d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -714,7 +714,7 @@ Now the `FlaxMLPModel` will have a similar interface as PyTorch or Tensorflow mo
So the important point to remember is that the `model` is not an instance of `nn.Module`; it's an abstract class, like a container that holds a Flax module, its parameters and provides convenient methods for initialization and forward pass. The key take-away here is that an instance of `FlaxMLPModel` is very much stateful now since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. Now to make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well to make it stateless and easier to work with during training. Feel free to take a look at the code to see how exactly this is implemented for ex. [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py#L536)
Another significant difference between Flax and PyTorch models is that, we can pass the `labels` directly to PyTorch's forward pass to compute the loss, whereas Flax models never accept `labels` as an input argument. In PyTorch, gradient backpropagation is performed by simply calling `.backward()` on the computed loss which makes it very handy for the user to be able to pass the `labels`. In Flax however, gradient backpropagation cannot be done by simply calling `.backward()` on the loss output, but the loss function itself has to be transformed by `jax.grad` or `jax.value_and_grad` to return the gradients of all parameters. This transformation cannot happen under-the-hood when one passes the `labels` to Flax's forward function, so that in Flax, we simply don't allow `labels` to be passed by design and force the user to implement the loss function her-/himself. As a conclusion, you will see that all training-related code is decoupled from the modeling code and always defined in the training scripts themselves.
Another significant difference between Flax and PyTorch models is that, we can pass the `labels` directly to PyTorch's forward pass to compute the loss, whereas Flax models never accept `labels` as an input argument. In PyTorch, gradient backpropagation is performed by simply calling `.backward()` on the computed loss which makes it very handy for the user to be able to pass the `labels`. In Flax however, gradient backpropagation cannot be done by simply calling `.backward()` on the loss output, but the loss function itself has to be transformed by `jax.grad` or `jax.value_and_grad` to return the gradients of all parameters. This transformation cannot happen under-the-hood when one passes the `labels` to Flax's forward function, so that in Flax, we simply don't allow `labels` to be passed by design and force the user to implement the loss function oneself. As a conclusion, you will see that all training-related code is decoupled from the modeling code and always defined in the training scripts themselves.
### **How to use flax models and example scripts**
@ -769,7 +769,7 @@ model = FlaxGPT2ForCausalLM(config)
```
As explained above we don't compute the loss inside the model, but rather in the task-specific training script.
For demonstration purposes, we write a pseude training script for causal lanuage modeling in the following.
For demonstration purposes, we write a pseudo training script for causal language modeling in the following.
```python
from flax.training.common_utils import onehot