Update serving signatures and make sure we actually use them (#19034)

* Override save() to use the serving signature as the default

* Replace int32 with int64 in all our serving signatures

* Remember one very important line so as not to break every test at once

* Dtype fix for TFLED

* dtype fix for shift_tokens_right in general

* Dtype fixes in mBART and RAG

* Fix dtypes for test_unpack_inputs

* More dtype fixes

* Yet more mBART + RAG dtype fixes

* Yet more mBART + RAG dtype fixes

* Add a check that the model actually has a serving method
Matt 2022-09-15 14:34:22 +01:00 committed by GitHub
parent 9b80a0bc18
commit 2322eb8e2f
35 changed files with 179 additions and 97 deletions
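The theme running through the diffs below: a tf.function serving signature pins its input dtypes, so a spec declared as int32 rejects int64 tensors instead of casting them. A minimal standalone sketch of that behaviour (not part of the commit; the toy serving function is hypothetical):

import tensorflow as tf

# A serving-style function pinned to int64, mirroring the updated signatures.
@tf.function(
    input_signature=[{"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids")}]
)
def serving(inputs):
    return {"doubled": inputs["input_ids"] * 2}

# A matching dtype traces and runs:
print(serving({"input_ids": tf.ones((1, 3), dtype=tf.int64)}))

# A mismatched dtype is rejected rather than silently cast:
try:
    serving({"input_ids": tf.ones((1, 3), dtype=tf.int32)})
except (TypeError, ValueError) as err:
    print("rejected:", type(err).__name__)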

View File

@@ -560,6 +560,18 @@ def input_processing(func, config, **kwargs):
if "kwargs" in output:
del output["kwargs"]
cast_output = dict()
for key, val in output.items():
if isinstance(val, tf.Tensor) and val.dtype == tf.int32:
cast_output[key] = tf.cast(val, tf.int64)
elif isinstance(val, np.ndarray) and val.dtype == np.int32:
cast_output[key] = val.astype(np.int64)
else:
cast_output[key] = val
output = cast_output
del cast_output
if config is not None:
boolean_dict = {
k: v
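In isolation, the int32-to-int64 coercion added to input_processing above does the following (standalone sketch; the helper name upcast_int32 is hypothetical):

import numpy as np
import tensorflow as tf

def upcast_int32(output):
    # Mirror of the loop above: upcast int32 tensors/arrays so they match
    # the int64 serving signatures; leave everything else untouched.
    cast_output = {}
    for key, val in output.items():
        if isinstance(val, tf.Tensor) and val.dtype == tf.int32:
            cast_output[key] = tf.cast(val, tf.int64)
        elif isinstance(val, np.ndarray) and val.dtype == np.int32:
            cast_output[key] = val.astype(np.int64)
        else:
            cast_output[key] = val
    return cast_output

print(upcast_int32({"input_ids": tf.constant([1, 2]), "mask": np.array([1, 0], dtype=np.int32)}))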
@@ -1054,9 +1066,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int64, name="token_type_ids"),
}
]
)
@@ -1082,6 +1094,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
"""
raise NotImplementedError
def save(
self,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None,
save_traces=True,
):
# Very simple wrapper that ensures we set the correct serving signature when saving
if signatures is None and hasattr(self, "serving"):
signatures = self.serving
super().save(
filepath,
overwrite=overwrite,
include_optimizer=include_optimizer,
save_format=save_format,
signatures=signatures,
options=options,
save_traces=save_traces,
)
def get_input_embeddings(self) -> tf.keras.layers.Layer:
"""
Returns the model's input embeddings layer.

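What the save() override buys, as a standalone TF 2.x-era sketch (the Toy model and toy_export path are hypothetical): passing a tf.function with an input_signature as signatures= makes it the SavedModel's serving_default, so the exported model keeps the int64 specs.

import tensorflow as tf

class Toy(tf.keras.Model):
    # Hypothetical stand-in for a TFPreTrainedModel subclass.
    def call(self, inputs):
        return {"logits": tf.cast(inputs["input_ids"], tf.float32)}

    @tf.function(
        input_signature=[{"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids")}]
    )
    def serving(self, inputs):
        return self.call(inputs)

model = Toy()
model({"input_ids": tf.ones((1, 4), dtype=tf.int64)})  # build once
model.save("toy_export", signatures=model.serving)  # what the override now does by default

loaded = tf.saved_model.load("toy_export")
print(loaded.signatures["serving_default"].structured_input_signature)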
View File

@@ -64,11 +64,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
@@ -475,10 +479,10 @@ class TFBartPretrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
}
]
)

View File

@@ -1799,9 +1799,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -66,11 +66,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"

View File

@@ -65,11 +65,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"

View File

@@ -1097,8 +1097,8 @@ class TFCLIPTextModel(TFCLIPPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -1127,9 +1127,9 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -424,8 +424,8 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -376,8 +376,8 @@ class TFDPRPretrainedReader(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -1511,9 +1511,9 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
"token_type_ids": tf.TensorSpec((None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -548,8 +548,8 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -527,8 +527,8 @@ class TFGPTJPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -1312,8 +1312,8 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -988,10 +988,10 @@ class TFLayoutLMv3PreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"bbox": tf.TensorSpec((None, None, 4), tf.int32, name="bbox"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"bbox": tf.TensorSpec((None, None, 4), tf.int64, name="bbox"),
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -55,16 +55,23 @@ _TOKENIZER_FOR_DOC = "LEDTokenizer"
LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

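The dtype on tf.constant(0, ...) above matters for the same reason: the comparison underneath tf.debugging.assert_greater_equal requires both operands to share a dtype. A standalone check (toy values):

import tensorflow as tf

ids = tf.constant([[0, 5, 7]], dtype=tf.int64)

tf.debugging.assert_greater_equal(ids, tf.constant(0, dtype=ids.dtype))  # passes

try:
    tf.debugging.assert_greater_equal(ids, tf.constant(0))  # int32 threshold vs int64 ids
except (TypeError, tf.errors.InvalidArgumentError) as err:
    print("dtype mismatch:", type(err).__name__)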
View File

@@ -800,12 +800,12 @@ class TFLxmertPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"visual_feats": tf.TensorSpec((None, None, None), tf.float32, name="visual_feats"),
"visual_pos": tf.TensorSpec((None, None, None), tf.float32, name="visual_pos"),
"visual_attention_mask": tf.TensorSpec((None, None), tf.int32, name="visual_attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
"visual_attention_mask": tf.TensorSpec((None, None), tf.int64, name="visual_attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -65,11 +65,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"

View File

@@ -69,11 +69,15 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int):
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
input_ids = tf.where(input_ids == -100, tf.fill(shape_list(input_ids), pad_token_id), input_ids)
input_ids = tf.where(
input_ids == -100, tf.fill(shape_list(input_ids), tf.cast(pad_token_id, input_ids.dtype)), input_ids
)
language_id_index = (
tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1) - 1
)
language_id_index = tf.stack([tf.range(shape_list(input_ids)[0]), language_id_index], axis=-1)
language_id_index = tf.stack(
[tf.range(shape_list(input_ids)[0], dtype=input_ids.dtype), language_id_index], axis=-1
)
languages_ids = tf.gather_nd(input_ids, language_id_index)
shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), input_ids[:, :-1]], axis=-1)

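For reference, the mBART variant above rotates the language-id token (the last non-pad token) to the front rather than prepending a fixed start token. A standalone, int64-safe demo of the pattern (toy ids, with 250004 playing the language id):

import tensorflow as tf

input_ids = tf.constant([[10, 11, 12, 250004, 1, 1]], dtype=tf.int64)  # 1 = pad
pad = tf.constant(1, dtype=input_ids.dtype)

lang_index = tf.reduce_sum(tf.cast(input_ids != pad, input_ids.dtype), axis=-1) - 1
gather_index = tf.stack(
    [tf.range(tf.shape(input_ids, out_type=input_ids.dtype)[0]), lang_index], axis=-1
)
lang_ids = tf.gather_nd(input_ids, gather_index)
shifted = tf.concat([tf.expand_dims(lang_ids, axis=-1), input_ids[:, :-1]], axis=-1)
print(shifted)  # [[250004 10 11 12 250004 1]]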
View File

@@ -76,8 +76,8 @@ class TFMPNetPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -359,8 +359,8 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -425,8 +425,8 @@ class TFOPTPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -65,11 +65,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"

View File

@@ -1301,17 +1301,18 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
pad_token_id = self.generator.config.pad_token_id
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
shifted_input_ids = tf.cast(input_ids, tf.int32)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, :-1]], -1)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.cast(start_token_id, input_ids.dtype))
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.cast(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, shifted_input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
@@ -1324,7 +1325,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
n_docs = n_docs if n_docs is not None else self.config.n_docs
# shift tokens left (from original Pytorch's version)
target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1)
target = tf.concat(
[target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
axis=1,
)
rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
loss = self.hf_compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)
@@ -1571,7 +1575,10 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
):
# shift tokens left
target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1)
target = tf.concat(
[target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
axis=1,
)
# bos_token_id is None for T5
bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
@@ -1580,7 +1587,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
use_bos = bos_token_id is not None and equal_bos_token_id_all
def _mask_pads(ll, smooth_obj):
pad_mask = tf.equal(target, self.config.generator.pad_token_id)
pad_mask = tf.equal(target, tf.cast(self.config.generator.pad_token_id, target.dtype))
if tf.reduce_any(pad_mask):
ll = tf.where(pad_mask, 0.0, ll)
smooth_obj = tf.where(pad_mask, 0.0, smooth_obj)
@@ -1611,7 +1618,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
def torch_gather(param, id_tensor):
# 2d-gather torch equivalent: https://stackoverflow.com/questions/52129909/tensorflow-equivalent-of-torch-gather
def gather2d(target, id_tensor):
idx = tf.stack([tf.range(tf.shape(id_tensor)[0]), id_tensor[:, 0]], axis=-1)
idx = tf.stack([tf.range(tf.shape(id_tensor)[0], dtype=id_tensor.dtype), id_tensor[:, 0]], axis=-1)
result = tf.gather_nd(target, idx)
return tf.expand_dims(result, axis=-1)

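The gather2d above is the torch.gather equivalent the linked Stack Overflow answer describes; with tf.range now taking its dtype from the index tensor, mixed integer widths can no longer reach tf.stack. A standalone check (toy values; here the dtype is pinned via tf.shape's out_type, which is equivalent):

import tensorflow as tf

param = tf.constant([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
ids = tf.constant([[2], [0]], dtype=tf.int64)

idx = tf.stack([tf.range(tf.shape(ids, out_type=ids.dtype)[0]), ids[:, 0]], axis=-1)
print(tf.gather_nd(param, idx))  # [30. 40.] -- row 0 col 2, row 1 col 0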
View File

@@ -1435,9 +1435,9 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -798,8 +798,8 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -1211,9 +1211,9 @@ class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -67,11 +67,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
@@ -591,9 +595,9 @@ class TFSpeech2TextPreTrainedModel(TFPreTrainedModel):
input_signature=[
{
"input_features": tf.TensorSpec((None, None, None), tf.float32, name="input_features"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
}
]
)

View File

@@ -872,10 +872,10 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
}
]
)

View File

@@ -865,9 +865,9 @@ class TFTapasPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -686,7 +686,7 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
}
]
)

View File

@@ -1345,8 +1345,8 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -636,8 +636,8 @@ class TFXGLMPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)

View File

@@ -1563,9 +1563,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)

View File

@@ -1685,16 +1685,21 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype))
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

View File

@@ -1887,9 +1887,9 @@ class UtilsFunctionsTest(unittest.TestCase):
return pixel_values, output_attentions, output_hidden_states, return_dict
dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3])
past = tf.constant([4, 5, 6, 7])
pixel_values = tf.constant([8, 9, 10, 11])
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int64)
past = tf.constant([4, 5, 6, 7], dtype=tf.int64)
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int64)
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past=past)