Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Commit: fa6506b7c0 ("optimize the code")
Parent: 58852ee6c9
@@ -930,8 +930,8 @@ class ModernBertModel(ModernBertPreTrainedModel):
         )

-        # Expand the attention mask
-        if self.config._attn_implementation == "sdpa" and attention_mask.dim() == 2:
-            # Expand the attention mask for SDPA.
+        if attention_mask.dim() == 2:
+            # Expand the attention mask
             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
             global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype, tgt_len=input_shape[1])
         else:
The same change is applied to a second occurrence of the block in ModernBertModel:

@@ -1060,8 +1060,8 @@ class ModernBertModel(ModernBertPreTrainedModel):
         )

-        # Expand the attention mask
-        if self.config._attn_implementation == "sdpa" and attention_mask.dim() == 2:
-            # Expand the attention mask for SDPA.
+        if attention_mask.dim() == 2:
+            # Expand the attention mask
             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
             global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype, tgt_len=input_shape[1])
         else:
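For context, `_prepare_4d_attention_mask` (from `transformers.modeling_attn_mask_utils`) is the helper the diff calls to turn a 2-D padding mask into the 4-D additive mask that the attention kernels consume. Below is a minimal sketch of that expansion, assuming the usual 1 = keep / 0 = pad convention; the function name `expand_attention_mask` is illustrative only, not part of transformers:

import torch

def expand_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len=None) -> torch.Tensor:
    # Sketch of the expansion: a 2-D padding mask [bsz, src_len] becomes a
    # 4-D additive mask [bsz, 1, tgt_len, src_len] with 0 where attention is
    # allowed and the dtype's most negative value where it is masked.
    bsz, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded  # 0.0 = keep, 1.0 = pad
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

# Example: batch of 2, seq_len 4, last token of the second row padded.
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
out = expand_attention_mask(mask, torch.float32)
print(out.shape)  # torch.Size([2, 1, 4, 4]); out[1, 0, :, 3] is the masked column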