Mirror of https://github.com/huggingface/transformers.git
Synced 2025-08-02 19:21:31 +06:00
OPT/BioGPT: Improved attention mask shape exception (#23270)
parent 21741e8c7e
commit 466af1a356
@@ -546,6 +546,12 @@ class BioGptModel(BioGptPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+            )
 
         # embed positions
         positions = self.embed_positions(attention_mask, past_key_values_length)
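The BioGPT check rejects a mismatched mask before it is used to build the position embeddings, so the caller gets an explicit ValueError instead of a harder-to-read shape error further down. Below is a minimal, self-contained sketch of the same invariant in plain PyTorch; the helper name validate_attention_mask and the toy tensor sizes are illustrative, not part of this commit:

import torch

# Sketch of the invariant enforced above: the mask must span past + current tokens.
def validate_attention_mask(attention_mask, inputs_embeds, past_key_values_length, input_shape):
    if attention_mask is None:
        return torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
    if attention_mask.shape[1] != past_key_values_length + input_shape[1]:
        raise ValueError(
            f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
            f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
        )
    return attention_mask

# Toy shapes: 4 cached tokens, 2 new tokens, hidden size 8.
embeds = torch.randn(1, 2, 8)
validate_attention_mask(torch.ones(1, 6, dtype=torch.bool), embeds, 4, embeds.shape)  # passes
try:
    validate_attention_mask(torch.ones(1, 2, dtype=torch.bool), embeds, 4, embeds.shape)
except ValueError as e:
    print(e)  # the mask covers only the 2 new tokens, not the 4 cached ones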
@@ -642,6 +642,11 @@ class OPTDecoder(OPTPreTrainedModel):
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != mask_seq_length:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
         causal_attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
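In the OPT decoder, mask_seq_length is the sum of the cached and current sequence lengths, so a caller doing incremental decoding with a KV cache must keep extending the mask every step. A small illustrative decode loop satisfying that invariant (the loop and tensor sizes are a sketch, not taken from the commit):

import torch

batch_size, past_len = 1, 0
attention_mask = torch.ones(batch_size, 0)

for step_tokens in (5, 1, 1):  # a 5-token prompt, then two single-token steps
    # Extend the mask so it spans past + current tokens, as the new check requires.
    attention_mask = torch.cat([attention_mask, torch.ones(batch_size, step_tokens)], dim=1)
    mask_seq_length = past_len + step_tokens
    assert attention_mask.shape[1] == mask_seq_length  # otherwise the check above raises
    past_len += step_tokens

print(attention_mask.shape)  # torch.Size([1, 7])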
@@ -645,6 +645,15 @@ class TFOPTDecoder(tf.keras.layers.Layer):
 
         if attention_mask is None:
             attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+        else:
+            tf.debugging.assert_equal(
+                attention_mask.shape[1],
+                past_key_values_length + input_shape[1],
+                message=(
+                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+                ),
+            )
 
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
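On the TF side the same invariant is enforced with tf.debugging.assert_equal rather than a Python raise, which raises InvalidArgumentError and stays usable when the layer is traced into a graph. A standalone illustration of that behavior (the helper name and message text are illustrative, not part of the commit):

import tensorflow as tf

# Eager-mode illustration: assert_equal raises InvalidArgumentError on mismatch.
def check_mask(attention_mask, expected_length):
    tf.debugging.assert_equal(
        attention_mask.shape[1],
        expected_length,
        message=f"mask has length {attention_mask.shape[1]}, expected {expected_length}",
    )

check_mask(tf.ones((1, 6), dtype=tf.bool), 6)  # passes silently
try:
    check_mask(tf.ones((1, 2), dtype=tf.bool), 6)  # wrong length
except tf.errors.InvalidArgumentError as e:
    print(e.message)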