[Flax] Fix BigBird (#13380)

* finish

* finish
This commit is contained in:
Patrick von Platen 2021-09-01 18:33:54 +02:00 committed by GitHub
parent ecd5397106
commit 4475f1dc2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2029,6 +2029,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
if token_type_ids is None:
token_type_ids = (~logits_mask).astype("i4")
logits_mask = jnp.expand_dims(logits_mask, axis=2)
logits_mask = jax.ops.index_update(logits_mask, jax.ops.index[:, 0], False)
# init input tensors if not passed
if token_type_ids is None: