Prophetnet optimization (#9453)

* Vectorized `ngram_attention_bias` calculation

* updated formatting with black

* Further optimization

* one (last) optimization
This commit is contained in:
guillaume-be 2021-01-07 11:41:58 +01:00 committed by GitHub
parent 28d74872cc
commit 390cf16bc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -171,13 +171,15 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
"""
This function computes the bias for the predict stream
"""
bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf")
left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
right_block = left_block.detach().clone()
# create bias
for stream_idx in range(ngram):
for i in range(sequence_length):
bias[stream_idx, i, sequence_length + i] = 0
bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0
return bias
right_block[stream_idx].fill_diagonal_(0, wrap=False)
left_block[stream_idx].triu_(-stream_idx + 1)
left_block[:, :, 0] = 0
return torch.cat([left_block, right_block], dim=2)
def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):