Optimize the code: adjust the attention-mask expansion condition in ModernBertModel

This commit is contained in:
BUI Van Tuan 2025-06-08 15:34:08 +02:00
parent 58852ee6c9
commit fa6506b7c0
2 changed files with 4 additions and 4 deletions

View File

@@ -930,8 +930,8 @@ class ModernBertModel(ModernBertPreTrainedModel):
)
# Expand the attention mask
if self.config._attn_implementation == "sdpa" and attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
if attention_mask.dim() == 2:
# Expand the attention mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype, tgt_len=input_shape[1])
else:

View File

@@ -1060,8 +1060,8 @@ class ModernBertModel(ModernBertPreTrainedModel):
)
# Expand the attention mask
if self.config._attn_implementation == "sdpa" and attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
if attention_mask.dim() == 2:
# Expand the attention mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype, tgt_len=input_shape[1])
else: