mirror of https://github.com/huggingface/transformers.git
Fix LXMERT with DataParallel (#7471)
parent 35e94c68df
commit 886ef35ce6
@@ -958,7 +958,7 @@ class LxmertModel(LxmertPreTrainedModel):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

         # Process the visual attention mask
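Why this fixes DataParallel: under torch.nn.DataParallel (PyTorch 1.5 and later), the per-device replicas built for forward() hold their weights as plain tensor attributes rather than registered nn.Parameters, so next(self.parameters()) raises StopIteration inside a replica. The self.dtype property on PreTrainedModel can resolve the dtype even in that case. Below is a minimal stand-alone sketch of such a fallback; module_dtype is a hypothetical helper written for illustration, not a transformers API.

import torch
import torch.nn as nn


def module_dtype(module: nn.Module) -> torch.dtype:
    # Hypothetical helper (illustration only): prefer registered parameters,
    # then fall back to plain tensor attributes, which is where DataParallel
    # replicas keep their weights on PyTorch >= 1.5.
    try:
        return next(module.parameters()).dtype
    except StopIteration:
        # Replica path: parameters() is empty, but the replicated weights
        # still live in each submodule's __dict__ as ordinary tensors.
        for submodule in module.modules():
            for value in submodule.__dict__.values():
                if torch.is_tensor(value):
                    return value.dtype
    raise ValueError("module holds no tensors to infer a dtype from")


# Ordinary module: the try branch succeeds immediately.
print(module_dtype(nn.Linear(4, 2)))  # torch.float32

On an ordinary module the try branch returns right away; the __dict__ scan only runs on a replica whose parameters() iterator is empty, which is exactly the situation the committed one-line change sidesteps.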