mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 22:38:58 +06:00
flax.linen.apply takes state as the first param, followed by the input (#12510)
This commit is contained in:
parent
f1c81d6b92
commit
eceb1042c1
@ -621,7 +621,7 @@ state = model_flax.init(rng, dummy_input_ids)
|
||||
and then we can do the forward pass.
|
||||
|
||||
```python
|
||||
sequences = model_flax.apply(input_ids, state)
|
||||
sequences = model_flax.apply(state, input_ids)
|
||||
```
|
||||
|
||||
Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object:
|
||||
|
Loading…
Reference in New Issue
Block a user