Remove redundant nn.log_softmax in run_flax_glue.py (#11920)

* Remove redundant `nn.log_softmax` in `run_flax_glue.py`

`optax.softmax_cross_entropy` expects unnormalized logits and already applies `log_softmax` internally, so the explicit call here is not needed. Since `log_softmax` is idempotent, the redundant call made no mathematical difference (see the sketch after these notes).

* Remove unused `flax.linen` import
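
A minimal sketch (not part of the commit) of the idempotence point above, using `jax.nn.log_softmax` (the function `flax.linen` re-exports as `nn.log_softmax`): `optax.softmax_cross_entropy` produces the same loss whether it is given raw logits or logits that were already passed through `log_softmax`.

    import jax
    import jax.numpy as jnp
    import optax

    # Toy logits and one-hot labels standing in for the script's model outputs.
    logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]])
    labels = jax.nn.one_hot(jnp.array([0, 2]), num_classes=3)

    # optax.softmax_cross_entropy applies log_softmax internally, and
    # log_softmax(log_softmax(x)) == log_softmax(x), so both losses agree.
    loss_raw = jnp.mean(optax.softmax_cross_entropy(logits, labels))
    loss_pre = jnp.mean(optax.softmax_cross_entropy(jax.nn.log_softmax(logits), labels))
    print(jnp.allclose(loss_raw, loss_pre))  # True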
Nicholas Vadivelu 2021-05-31 10:29:04 -04:00 committed by GitHub
parent fb60c309c6
commit 1ab147d648


@@ -29,7 +29,6 @@ import jax
 import jax.numpy as jnp
 import optax
 import transformers
-from flax import linen as nn
 from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
@@ -202,7 +201,6 @@ def create_train_state(
     else:  # Classification.
         def cross_entropy_loss(logits, labels):
-            logits = nn.log_softmax(logits)
             xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
             return jnp.mean(xentropy)
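
For reference, a runnable sketch of the loss as it reads after this change, checked against the explicit -sum(onehot(y) * log_softmax(logits)) form. `onehot` is assumed here to be `flax.training.common_utils.onehot`, and `num_labels` is a stand-in value; neither assumption comes from the diff above.

    import jax
    import jax.numpy as jnp
    import optax
    from flax.training.common_utils import onehot  # assumed source of `onehot`

    num_labels = 3  # stand-in for the dataset's label count

    def cross_entropy_loss(logits, labels):
        xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
        return jnp.mean(xentropy)

    logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]])
    labels = jnp.array([0, 2])

    # Same value as writing the softmax cross-entropy out by hand.
    manual = jnp.mean(-jnp.sum(onehot(labels, num_classes=num_labels) * jax.nn.log_softmax(logits), axis=-1))
    print(jnp.allclose(cross_entropy_loss(logits, labels), manual))  # True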