Remove redundant nn.log_softmax in run_flax_glue.py (#11920)
* Remove redundant `nn.log_softmax` in `run_flax_glue.py`

  `optax.softmax_cross_entropy` expects unnormalized logits and already applies `nn.log_softmax` internally, so the extra call is not needed here. Since `nn.log_softmax` is idempotent, removing it makes no mathematical difference.

* Remove unused `flax.linen` import
This commit is contained in:
parent fb60c309c6 · commit 1ab147d648
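The reasoning in the commit message can be checked numerically. Below is a minimal sketch (not part of the commit) showing that `optax.softmax_cross_entropy` accepts raw, unnormalized logits, and that because `jax.nn.log_softmax` is idempotent, normalizing the logits first yields exactly the same loss. The sample values and the use of `jax.nn.one_hot` in place of the script's `onehot` helper are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import optax

# Raw, unnormalized scores and one-hot targets (illustrative values only).
logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]])
labels = jax.nn.one_hot(jnp.array([0, 2]), num_classes=3)

# Loss on raw logits vs. loss on logits that were log-normalized first.
loss_raw = optax.softmax_cross_entropy(logits, labels).mean()
loss_normalized = optax.softmax_cross_entropy(jax.nn.log_softmax(logits), labels).mean()

# Idempotence: log_softmax(log_softmax(x)) == log_softmax(x), so both losses match.
assert jnp.allclose(loss_raw, loss_normalized)
```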
@@ -29,7 +29,6 @@ import jax
 import jax.numpy as jnp
 import optax
 import transformers
-from flax import linen as nn
 from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
@@ -202,7 +201,6 @@ def create_train_state(
     else:  # Classification.

         def cross_entropy_loss(logits, labels):
-            logits = nn.log_softmax(logits)
             xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
             return jnp.mean(xentropy)

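For reference, a small sketch of what `optax.softmax_cross_entropy` computes, spelled out with `jax.numpy`: the log-normalization step already happens inside the loss, which is why the extra `nn.log_softmax` in `cross_entropy_loss` was redundant. The `manual_softmax_cross_entropy` helper and the sample values are illustrative assumptions, not code from the repository.

```python
import jax
import jax.numpy as jnp
import optax

def manual_softmax_cross_entropy(logits, onehot_labels):
    # Equivalent to -sum(labels * log_softmax(logits)) along the class axis,
    # i.e. the log-normalization is part of the loss itself.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(onehot_labels * log_probs, axis=-1)

logits = jnp.array([[1.5, -0.3, 0.2]])
labels = jax.nn.one_hot(jnp.array([0]), num_classes=3)

# The hand-written version matches optax's implementation on these inputs.
assert jnp.allclose(
    manual_softmax_cross_entropy(logits, labels),
    optax.softmax_cross_entropy(logits, labels),
)
```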