mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix Tensorflow T5 with int64 input (#13479)
* Fix Tensorflow T5 with int64 input * Style pass
This commit is contained in:
parent
361b6df36a
commit
707105290b
@ -874,16 +874,21 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
||||
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||
|
||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
|
||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||
|
||||
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = tf.where(
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
shifted_input_ids == -100,
|
||||
tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
|
||||
shifted_input_ids,
|
||||
)
|
||||
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(
|
||||
shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
|
||||
)
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
|
Loading…
Reference in New Issue
Block a user