Mirror of https://github.com/huggingface/transformers.git
make build_mpt_alibi_tensor a method of MptModel so that deepspeed could override it to make autoTP work (#25193)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 4033ea7167 (parent 0fd8d2aa2c)
src/transformers/models/mpt/modeling_mpt.py

@@ -413,6 +413,9 @@ class MptModel(MptPreTrainedModel):
     def get_input_embeddings(self):
         return self.wte
 
+    def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
+        return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
+
     def _prepare_attn_mask(
         self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
     ) -> torch.BoolTensor:
@@ -507,7 +510,7 @@ class MptModel(MptPreTrainedModel):
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
-        alibi = build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
+        alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
 
         causal_mask = self._prepare_attn_mask(
             attention_mask,
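The point of the indirection is that an external library such as DeepSpeed AutoTP can now replace the alibi construction per tensor-parallel rank by rebinding the method, instead of patching the module-level function. Below is a minimal sketch of that idea; it is not DeepSpeed's actual implementation, and the sharded_build_mpt_alibi_tensor override plus the hard-coded tp_rank/tp_world_size values are illustrative assumptions only.

# Hedged sketch: rebind build_mpt_alibi_tensor on an MptModel instance so that
# forward() only sees the alibi slopes for the heads owned by the local
# tensor-parallel rank. tp_rank/tp_world_size are assumed values here; a real
# integration would read them from the distributed runtime.
import types

from transformers import MptConfig, MptModel
from transformers.models.mpt.modeling_mpt import build_mpt_alibi_tensor


def sharded_build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
    # Build the full alibi tensor, then keep only this rank's slice of heads.
    tp_rank, tp_world_size = 0, 2  # assumed; normally taken from the distributed state
    alibi = build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
    heads_per_rank = num_heads // tp_world_size
    return alibi[tp_rank * heads_per_rank : (tp_rank + 1) * heads_per_rank]


model = MptModel(MptConfig(d_model=64, n_heads=8, n_layers=2))
# Because MptModel.forward now calls self.build_mpt_alibi_tensor(...), rebinding
# the method on the instance is enough to change what forward() receives.
model.build_mpt_alibi_tensor = types.MethodType(sharded_build_mpt_alibi_tensor, model)
print(model.build_mpt_alibi_tensor(8, 16).shape)  # heads dimension should now be 8 // 2 = 4

Note that the new method in the diff is a thin wrapper around the existing module-level build_mpt_alibi_tensor, so behavior is unchanged when nothing overrides it.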