mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
cast bool tensor to long for pytorch < 1.3
This commit is contained in:
parent
9f75565ea8
commit
4d18199902
@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
|
Loading…
Reference in New Issue
Block a user