fix wrong 'cls' masking for bigbird qa model output (#13143)

This commit is contained in:
donggyukimc 2021-09-01 21:03:16 +09:00 committed by GitHub
parent 7a26307e31
commit ba1b3db709
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2987,6 +2987,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
if token_type_ids is None:
token_type_ids = (~logits_mask).long()
logits_mask = logits_mask
logits_mask[:, 0] = False
logits_mask.unsqueeze_(2)
outputs = self.bert(