diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index 5f92fb96a3e..1ee3e3f097c 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -667,11 +667,10 @@ class BertModel(BertPreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
-
-        # Provided a padding mask of dimensions [batch_size, seq_length]
-        # - if the model is a decoder, apply a causal mask in addition to the padding mask
-        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if attention_mask.dim() == 2:
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
             if self.config.is_decoder:
                 batch_size, seq_length = input_shape
                 seq_ids = torch.arange(seq_length, device=device)
@@ -679,6 +678,8 @@ class BertModel(BertPreTrainedModel):
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -696,8 +697,11 @@ class BertModel(BertPreTrainedModel):

         if encoder_attention_mask.dim() == 3:
             encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
-        if encoder_attention_mask.dim() == 2:
+        elif encoder_attention_mask.dim() == 2:
             encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+        else:
+            raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape,
+                                                                                                                encoder_attention_mask.shape))

         encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
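
A minimal sketch of the effect of this patch (the checkpoint name and tensor shapes below are illustrative, not part of the diff): a mask whose rank is neither 2 nor 3 previously left `extended_attention_mask` undefined and failed later with an UnboundLocalError; with the added `else` branches it now fails fast with a ValueError that reports both shapes.

```python
import torch
from transformers import BertModel

# Any BERT checkpoint works here; "bert-base-uncased" is just an example.
model = BertModel.from_pretrained("bert-base-uncased")
input_ids = torch.ones((2, 8), dtype=torch.long)

# A 2D padding mask (or a 3D self-attention mask) is still accepted as before.
model(input_ids, attention_mask=torch.ones((2, 8), dtype=torch.long))

# Any other rank now triggers the new error path instead of an UnboundLocalError.
try:
    model(input_ids, attention_mask=torch.ones((2, 1, 1, 8), dtype=torch.long))
except ValueError as err:
    print(err)  # Wrong shape for input_ids (shape ...) or attention_mask (shape ...)
```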