Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Merge pull request #2068 from huggingface/fix-2042
Nicer error message when Bert's input is missing batch size
Commit fc1bb1f867
@@ -667,11 +667,10 @@ class BertModel(BertPreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
-
-        # Provided a padding mask of dimensions [batch_size, seq_length]
-        # - if the model is a decoder, apply a causal mask in addition to the padding mask
-        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if attention_mask.dim() == 2:
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
             if self.config.is_decoder:
                 batch_size, seq_length = input_shape
                 seq_ids = torch.arange(seq_length, device=device)
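For reference, a minimal sketch (not part of the commit) of what the 2-D branch above produces, assuming a toy batch of two sequences of length four with trailing padding:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])                  # [batch_size, seq_length] = [2, 4]
batch_size, seq_length = attention_mask.shape

# Encoder path: make the padding mask broadcastable to
# [batch_size, num_heads, from_seq_length, to_seq_length]
extended_attention_mask = attention_mask[:, None, None, :]     # shape [2, 1, 1, 4]

# Decoder path: combine a lower-triangular causal mask with the padding mask
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
decoder_extended_mask = causal_mask[:, None, :, :].long() * attention_mask[:, None, None, :]

print(extended_attention_mask.shape, decoder_extended_mask.shape)
# torch.Size([2, 1, 1, 4]) torch.Size([2, 1, 4, 4])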
@@ -679,6 +678,8 @@ class BertModel(BertPreTrainedModel):
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
 
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
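A hypothetical reproduction of the case the new else branch reports, assuming a transformers install from around this release: an unbatched 1-D input (the "missing batch size" case from the PR title) leaves attention_mask with dim() == 1, so neither branch applies and the model now raises the ValueError above instead of a less descriptive failure. The model name and sentence are placeholders.

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

input_ids = torch.tensor(tokenizer.encode("Hello world"))   # shape [seq_length]: batch dim missing
try:
    model(input_ids)                                         # should be input_ids.unsqueeze(0), i.e. [1, seq_length]
except ValueError as err:
    print(err)   # e.g. "Wrong shape for input_ids (shape ...) or attention_mask (shape ...)"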
@@ -696,8 +697,11 @@ class BertModel(BertPreTrainedModel):
 
         if encoder_attention_mask.dim() == 3:
             encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
-        if encoder_attention_mask.dim() == 2:
+        elif encoder_attention_mask.dim() == 2:
             encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+        else:
+            raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape,
+                                                                                                                 encoder_attention_mask.shape))
 
         encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
         encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
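As a side note, a small sketch (assumed, not from the diff) of why the context line above rewrites the 0/1 mask as (1.0 - mask) * -10000.0: kept positions get an additive bias of 0.0, masked positions get -10000.0, and softmax then drives the masked positions to near-zero attention weight.

import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])            # 1 = attend, 0 = masked
additive_mask = (1.0 - mask) * -10000.0           # tensor([[0., 0., -10000.]])

scores = torch.tensor([[2.0, 1.0, 3.0]])          # raw attention scores
probs = torch.softmax(scores + additive_mask, dim=-1)
print(probs)                                      # ~tensor([[0.7311, 0.2689, 0.0000]])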