Falcon: removed unused function (#28605)

Joao Gante 2024-01-27 15:52:59 +00:00 committed by GitHub
parent de13a951b3
commit a28a76996c

@@ -214,17 +214,6 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-def _prepare_4d_attention_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
-    """
-    Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
-    """
-    batch_size, total_length = mask.shape
-    seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
-
-    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
-    return expanded_mask.expand(batch_size, 1, seq_length, total_length)
-
-
 def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
     batch_size, seq_length = attention_mask.shape
     closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
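
For reference, the deleted helper can be reproduced standalone to see what it computed: it inverts a `[batch_size, seq_length]` padding mask (1 = attend, 0 = pad) and broadcasts it to the 4D shape attention layers consume, with `True` marking masked positions. Below is a minimal sketch assuming only `torch`; the example mask and cache length are illustrative, not from the commit. The shared mask utilities in `transformers.modeling_attn_mask_utils` provide equivalent functionality, which is presumably why this module-local copy became dead code.

    import torch

    def _prepare_4d_attention_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
        # Same logic as the removed function: invert the padding mask and
        # broadcast it to [batch_size, 1, seq_length, total_length].
        batch_size, total_length = mask.shape
        seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
        expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
        return expanded_mask.expand(batch_size, 1, seq_length, total_length)

    # Illustrative inputs: two sequences of total length 5, of which 2 positions
    # sit in the KV cache; 0 marks padding in the usual attention-mask convention.
    mask = torch.tensor([[1, 1, 1, 1, 1],
                         [0, 0, 1, 1, 1]])
    out = _prepare_4d_attention_mask(mask, past_key_values_length=2)
    print(out.shape)     # torch.Size([2, 1, 3, 5])
    print(out[1, 0, 0])  # tensor([ True,  True, False, False, False]); True = masked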