fix tf led pt test (#9513)

This commit is contained in:
Patrick von Platen 2021-01-11 14:14:48 +01:00 committed by GitHub
parent 1e3c362235
commit 6c8ec2a931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -166,7 +166,13 @@ def prepare_led_inputs_dict(
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
if decoder_attention_mask is None:
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
decoder_attention_mask = tf.concat(
[
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
],
axis=-1,
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,