mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fix Tensorflow T5 with int64 input (#13479)
* Fix Tensorflow T5 with int64 input * Style pass
This commit is contained in:
parent
361b6df36a
commit
707105290b
@ -874,16 +874,21 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
|||||||
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||||
|
|
||||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||||
|
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
|
||||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||||
|
|
||||||
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
|
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
|
||||||
# replace possible -100 values in labels by `pad_token_id`
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
shifted_input_ids = tf.where(
|
shifted_input_ids = tf.where(
|
||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100,
|
||||||
|
tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
|
||||||
|
shifted_input_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
assert_gte0 = tf.debugging.assert_greater_equal(
|
||||||
|
shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
|
Loading…
Reference in New Issue
Block a user