flax.linen.apply takes state as the first param, followed by the input (#12510)

This commit is contained in:
Navjot 2021-07-05 07:03:14 -07:00 committed by GitHub
parent f1c81d6b92
commit eceb1042c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -621,7 +621,7 @@ state = model_flax.init(rng, dummy_input_ids)
and then we can do the forward pass.
```python
sequences = model_flax.apply(input_ids, state)
sequences = model_flax.apply(state, input_ids)
```
Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object: