Make sliding window size inclusive in eager attention (#29519)
* Make sliding window size inclusive in eager attention
* Fix tests

parent f386c51ad9
commit 608fa5496c
@@ -164,10 +164,10 @@ class AttentionMaskConverter:

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
-            diagonal = past_key_values_length - sliding_window + 1
+            diagonal = past_key_values_length - sliding_window - 1

-            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
-            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
+            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
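For context, a minimal sketch (not part of the commit) of how the new tril-based mask behaves: with the updated diagonal, a query position can attend to itself plus the previous sliding_window tokens, i.e. the window boundary is now inclusive. The names mirror those in the diff; the small sizes (tgt_len=5, past_key_values_length=0, sliding_window=2) are illustrative only.

import torch

dtype = torch.float32
tgt_len, past_key_values_length, sliding_window = 5, 0, 2

# plain causal mask, as built earlier in _make_causal_mask (0 = visible, finfo.min = masked)
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_cond = torch.arange(tgt_len)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)

# the updated sliding-window masking from this commit
diagonal = past_key_values_length - sliding_window - 1
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
mask.masked_fill_(context_mask, torch.finfo(dtype).min)

# each query row i now sees min(i + 1, sliding_window + 1) keys
print((mask == 0).sum(dim=-1))  # tensor([1, 2, 3, 3, 3])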
@@ -1673,7 +1673,7 @@ class AttentionMaskTester(unittest.TestCase):

    def compute_num_context_mask(self, kv_len, context, q_len):
        # This function computes the # of attention tokens that are added for
        # the sliding window
-        c_mask_len = kv_len - context
+        c_mask_len = kv_len - context - 1
        num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
        cut_mask_len = max(c_mask_len - q_len, 0)
        num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
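A rough cross-check (not part of the commit) of the closed-form count used by the test: with the extra "- 1", the triangle-number formula matches a brute-force count of the entries an inclusive sliding window masks on top of the causal mask. The brute_force_count helper and the sample sizes below are local to this sketch, assuming query row i sits at absolute position kv_len - q_len + i and may see keys in [q - context, q].

def compute_num_context_mask(kv_len, context, q_len):
    # closed form as in the updated test
    c_mask_len = kv_len - context - 1
    num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
    cut_mask_len = max(c_mask_len - q_len, 0)
    num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
    return num_mask_triangle - num_cut_mask

def brute_force_count(kv_len, context, q_len):
    # keys j < q - context are the ones the inclusive window masks
    # in addition to the causal mask
    count = 0
    for i in range(q_len):
        q = kv_len - q_len + i
        count += sum(1 for j in range(kv_len) if j < q - context)
    return count

for kv_len, context, q_len in [(10, 3, 10), (10, 3, 4), (7, 2, 1)]:
    assert compute_num_context_mask(kv_len, context, q_len) == brute_force_count(kv_len, context, q_len)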