Remove token_type_ids from default TF GPT-2 signature (#26962)

Remove token_type_ids from default GPT-2 signature
This commit is contained in:
Matt 2023-10-23 16:18:02 +01:00 committed by GitHub
parent c0b5ad9473
commit f7354a3bd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -521,6 +521,16 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"]
@property
def input_signature(self):
    """Return the default input signature for this model.

    GPT-2 can accept token_type_ids in principle, but they are seldom used in
    practice, and because of how they are implemented, passing
    token_type_ids=0 does not give the same outputs as token_type_ids=None.
    The argument is therefore left out of the default signature, even though
    it would normally be part of it.

    Returns:
        A dict mapping each input name ("input_ids", "attention_mask") to a
        `tf.TensorSpec` of shape (None, None) and dtype `tf.int32` carrying
        that same name.
    """
    # Both inputs share one shape/dtype, so build the specs from their names.
    signature_names = ("input_ids", "attention_mask")
    return {
        spec_name: tf.TensorSpec((None, None), tf.int32, name=spec_name)
        for spec_name in signature_names
    }
@dataclass
class TFGPT2DoubleHeadsModelOutput(ModelOutput):