[bloom] fix alibi device placement (#18087)
parent 8b332a6a16
commit ad28ca291b
@@ -93,7 +93,7 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
     )


-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -129,7 +129,7 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
     arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
     alibi = slopes * arange_tensor.expand(n_head, -1, -1)

-    alibi = alibi.to(dtype)
+    alibi = alibi.to(device=device, dtype=dtype)

     return alibi

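For context, below is a minimal sketch of an ALiBi bias built for a target device, illustrating what the changed function now does. It is not the repository's implementation: the helper name alibi_bias_sketch and the slope computation (geometric powers of 2 ** (-8 / n_head), valid for power-of-two head counts) are assumptions for illustration; only the device/dtype handling mirrors the hunk above.

import torch

def alibi_bias_sketch(max_seq_len, n_head, device, dtype=torch.bfloat16):
    # Hypothetical per-head slopes: geometric powers of 2 ** (-8 / n_head),
    # as described in the ALiBi paper for power-of-two head counts.
    base = 2 ** (-8.0 / n_head)
    slopes = torch.tensor([base ** (i + 1) for i in range(n_head)]).unsqueeze(-1).unsqueeze(-1)

    # Linear position bias per head: shape (n_head, 1, max_seq_len).
    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
    alibi = slopes * arange_tensor.expand(n_head, -1, -1)

    # The point of the fix: place the bias on the caller's device instead of
    # leaving it on the default (CPU) device.
    return alibi.to(device=device, dtype=dtype)

Called as alibi_bias_sketch(seq_len, n_head, hidden_states.device, hidden_states.dtype), the bias ends up on the same device as the activations it is added to.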
@@ -147,7 +147,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
     # This usually happens when the inference is done with past_key_values
     # In this case we re-create the alibi tensor with the correct sequence length
     if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
+        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
             attention_mask.shape[0], 1, 1
         )
     # Get the indexes of the padding tokens
@@ -156,7 +156,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):

     # Clone the embeddings - we can detach because the embeddings are not learned
     # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
+    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)

     # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
     # Only where you do not have padding. Replace padding tokens by zeros
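The two hunks above only thread the new device argument through pre_process_alibi_for_pad. A minimal sketch of the pattern they preserve, reusing the hypothetical alibi_bias_sketch helper from above (the shapes and mask below are made up for illustration):

import torch

num_heads = 8
alibi = alibi_bias_sketch(4, num_heads, torch.device("cpu"), torch.float32)  # cached bias, length 4
attention_mask = torch.ones(2, 6)  # batch of 2, sequence length 6 (e.g. with past_key_values)

# Mirrors the diff: when the mask is longer than the cached bias, rebuild the
# bias at the mask's length, keeping its device and dtype, then tile it across
# the batch.
if attention_mask.shape[-1] != alibi.shape[-1]:
    alibi = alibi_bias_sketch(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
        attention_mask.shape[0], 1, 1
    )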
@@ -767,7 +767,7 @@ class BloomModel(BloomPreTrainedModel):
         current_sequence_length = hidden_states.shape[1]
         if past_key_values[0] is not None:
             current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

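At the call site in BloomModel, the fix forwards hidden_states.device alongside hidden_states.dtype so the bias is created on the same device as the activations. A minimal sketch of that pattern, again using the hypothetical alibi_bias_sketch helper defined above rather than the real build_alibi_tensor:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(1, 10, 64, device=device, dtype=torch.float32)
n_head = 8

# Before the fix the bias was built on the default (CPU) device, so adding it
# to CUDA attention scores could fail with a device mismatch. Passing the
# activation device avoids that:
current_sequence_length = hidden_states.shape[1]
alibi = alibi_bias_sketch(current_sequence_length, n_head, hidden_states.device, hidden_states.dtype)
assert alibi.device == hidden_states.device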