mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
replace torch.triu with onnx compatible code (#6929)
This commit is contained in:
parent
ed71c21d6a
commit
9fd11bf1a8
@ -154,9 +154,10 @@ def _prepare_bart_decoder_inputs(
|
||||
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
|
||||
# never mask leading token, even if it is pad
|
||||
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
|
||||
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
|
||||
dtype=causal_mask_dtype, device=decoder_input_ids.device
|
||||
)
|
||||
tmp = fill_with_neg_inf(torch.zeros(tgt_len, tgt_len))
|
||||
mask = torch.arange(tmp.size(-1))
|
||||
tmp.masked_fill_(mask < (mask + 1).view(tmp.size(-1), 1), 0)
|
||||
causal_mask = tmp.to(dtype=causal_mask_dtype, device=decoder_input_ids.device)
|
||||
return decoder_input_ids, decoder_padding_mask, causal_mask
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user