cast bool tensor to long for pytorch < 1.3

This commit is contained in:
Rémi Louf 2019-11-12 17:59:34 +01:00 committed by Julien Chaumond
parent 9f75565ea8
commit 4d18199902

View File

@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel):
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]