Fix LXMERT with DataParallel (#7471)

This commit is contained in:
Lysandre Debut 2020-09-30 12:41:24 +02:00 committed by GitHub
parent 35e94c68df
commit 886ef35ce6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -958,7 +958,7 @@ class LxmertModel(LxmertPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Process the visual attention mask