From fa6506b7c0f4eb85e114a938f44eac8ae3be99a9 Mon Sep 17 00:00:00 2001
From: BUI Van Tuan
Date: Sun, 8 Jun 2025 15:34:08 +0200
Subject: [PATCH] optimize the code

---
 src/transformers/models/modernbert/modeling_modernbert.py | 4 ++--
 src/transformers/models/modernbert/modular_modernbert.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index b2e52f98e43..4154109ddc6 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -930,8 +930,8 @@ class ModernBertModel(ModernBertPreTrainedModel):
         )
 
         # Expand the attention mask
-        if self.config._attn_implementation == "sdpa" and attention_mask.dim() == 2:
-            # Expand the attention mask for SDPA.
+        if attention_mask.dim() == 2:
+            # Expand the attention mask
             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
             global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype, tgt_len=input_shape[1])
         else:
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 00a242bbe01..a31c591a7fa 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -1060,8 +1060,8 @@ class ModernBertModel(ModernBertPreTrainedModel):
         )
 
         # Expand the attention mask
-        if self.config._attn_implementation == "sdpa" and attention_mask.dim() == 2:
-            # Expand the attention mask for SDPA.
+        if attention_mask.dim() == 2:
+            # Expand the attention mask
             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
             global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype, tgt_len=input_shape[1])
         else:
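
Note (not part of the patch): with this change, the 2-D -> 4-D mask expansion runs whenever a 2-D attention_mask is passed, rather than only under the SDPA backend. Below is a minimal sketch of what _prepare_4d_attention_mask does at this call site, assuming the helper's documented behavior; the re-implementation, including the expand_padding_mask name, is illustrative and is not the transformers source.

import torch

def expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] of 1s (keep) / 0s (pad) -> [bsz, 1, tgt_len, src_len] additive mask.
    bsz, src_len = attention_mask.shape
    # Broadcast the key-side padding mask across every query position.
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Kept positions contribute 0; padded positions get the most negative
    # representable value, so softmax drives their attention weights to ~0.
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 1, 0]])    # batch of one sequence, last token padded
four_d = expand_padding_mask(mask, torch.float32, tgt_len=4)
print(four_d.shape)                    # torch.Size([1, 1, 4, 4])

This is the [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] shape change the in-code comment refers to, with tgt_len taken from input_shape[1].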