replace torch.triu with onnx compatible code (#6929)

This commit is contained in:
Henry Dashwood 2020-09-09 09:56:40 +01:00 committed by GitHub
parent ed71c21d6a
commit 9fd11bf1a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -154,9 +154,10 @@ def _prepare_bart_decoder_inputs(
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
# never mask leading token, even if it is pad
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
tmp = fill_with_neg_inf(torch.zeros(tgt_len, tgt_len))
mask = torch.arange(tmp.size(-1))
tmp.masked_fill_(mask < (mask + 1).view(tmp.size(-1), 1), 0)
causal_mask = tmp.to(dtype=causal_mask_dtype, device=decoder_input_ids.device)
return decoder_input_ids, decoder_padding_mask, causal_mask