OPT/BioGPT: Improved attention mask shape exception (#23270)

Joao Gante authored on 2023-05-16 13:59:53 +01:00, committed by GitHub
commit 466af1a356 (parent 21741e8c7e)
3 changed files with 20 additions and 0 deletions

src/transformers/models/biogpt/modeling_biogpt.py

@@ -546,6 +546,12 @@ class BioGptModel(BioGptPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+            )
+
         # embed positions
         positions = self.embed_positions(attention_mask, past_key_values_length)
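
For context, a minimal sketch (not part of the commit) of how the new BioGPT check surfaces to a user. The tiny config values are arbitrary, chosen only so the model builds quickly, and it assumes a transformers install that includes this change:

import torch
from transformers import BioGptConfig, BioGptModel

# Tiny, arbitrary config (illustrative values only).
config = BioGptConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
model = BioGptModel(config)

input_ids = torch.tensor([[2, 10, 20, 30]])    # 4 current tokens, no past
bad_mask = torch.ones(1, 3, dtype=torch.bool)  # length 3; expected past (0) + current (4) = 4

try:
    model(input_ids=input_ids, attention_mask=bad_mask)
except ValueError as err:
    print(err)  # The provided attention mask has length 3, but its length should be 4 ...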

src/transformers/models/opt/modeling_opt.py

@@ -642,6 +642,11 @@ class OPTDecoder(OPTPreTrainedModel):
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != mask_seq_length:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
         causal_attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
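
Likewise for PyTorch OPT, a sketch (again not from the commit) of the cache case the error message describes: after priming the key/value cache, the mask must cover past and current tokens together. Config sizes are arbitrary, and it assumes a transformers version containing this change:

import torch
from transformers import OPTConfig, OPTModel

# Tiny, arbitrary config (illustrative values only).
config = OPTConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, ffn_dim=128)
model = OPTModel(config)

# Prime the cache with 4 tokens.
out = model(input_ids=torch.tensor([[2, 10, 20, 30]]), use_cache=True)

# One new token, but a mask covering only that token: the decoder expects
# past (4) + current (1) = 5 mask positions, so the new check raises.
try:
    model(
        input_ids=torch.tensor([[40]]),
        attention_mask=torch.ones(1, 1),
        past_key_values=out.past_key_values,
    )
except ValueError as err:
    print(err)  # ... has length 1, but its length should be 5 ...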

src/transformers/models/opt/modeling_tf_opt.py

@@ -645,6 +645,15 @@ class TFOPTDecoder(tf.keras.layers.Layer):
         if attention_mask is None:
             attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+        else:
+            tf.debugging.assert_equal(
+                attention_mask.shape[1],
+                past_key_values_length + input_shape[1],
+                message=(
+                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+                ),
+            )
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
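
And the TF counterpart: because this check uses tf.debugging.assert_equal rather than raise, a mismatched mask surfaces as an InvalidArgumentError when running eagerly. A sketch with arbitrary tiny config values, assuming a TF-enabled install of transformers with this change:

import tensorflow as tf
from transformers import OPTConfig, TFOPTModel

# Tiny, arbitrary config (illustrative values only).
config = OPTConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, ffn_dim=128)
model = TFOPTModel(config)

input_ids = tf.constant([[2, 10, 20, 30]])  # 4 current tokens, no past
bad_mask = tf.ones((1, 3), dtype=tf.bool)   # length 3; expected 4

try:
    model(input_ids=input_ids, attention_mask=bad_mask)
except tf.errors.InvalidArgumentError as err:
    print(err.message)  # The provided attention mask has length 3, but ...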