Fix Tensorflow T5 with int64 input (#13479)

* Fix Tensorflow T5 with int64 input

* Style pass
This commit is contained in:
Matt 2021-09-08 15:06:04 +01:00 committed by GitHub
parent 361b6df36a
commit 707105290b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -874,16 +874,21 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
assert_gte0 = tf.debugging.assert_greater_equal(
shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):