mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Prophetnet optimization (#9453)
* Vectorized `ngram_attention_bias` calculation * updated formatting with black * Further optimization * one (last) optimization
This commit is contained in:
parent
28d74872cc
commit
390cf16bc8
@ -171,13 +171,15 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
|
||||
"""
|
||||
This function computes the bias for the predict stream
|
||||
"""
|
||||
bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf")
|
||||
left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
|
||||
right_block = left_block.detach().clone()
|
||||
# create bias
|
||||
for stream_idx in range(ngram):
|
||||
for i in range(sequence_length):
|
||||
bias[stream_idx, i, sequence_length + i] = 0
|
||||
bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0
|
||||
return bias
|
||||
right_block[stream_idx].fill_diagonal_(0, wrap=False)
|
||||
left_block[stream_idx].triu_(-stream_idx + 1)
|
||||
|
||||
left_block[:, :, 0] = 0
|
||||
return torch.cat([left_block, right_block], dim=2)
|
||||
|
||||
|
||||
def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
|
||||
|
Loading…
Reference in New Issue
Block a user