mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix tf led pt test (#9513)
This commit is contained in:
parent
1e3c362235
commit
6c8ec2a931
@ -166,7 +166,13 @@ def prepare_led_inputs_dict(
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
|
||||
decoder_attention_mask = tf.concat(
|
||||
[
|
||||
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
|
||||
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
|
Loading…
Reference in New Issue
Block a user