Allow passing kwargs through to TFBertTokenizer (#24324)

Matt 2023-06-20 12:49:06 +01:00 committed by GitHub
parent cfc838dd4d
commit 0875b2509a


@@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
         return_attention_mask (`bool`, *optional*, defaults to `True`):
             Whether to return the attention_mask.
         use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
-            If set to false will use standard TF Text BertTokenizer, making it servable by TF Serving.
+            If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
+            class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
+            TFLite.
     """

     def __init__(
@@ -65,11 +67,12 @@ class TFBertTokenizer(tf.keras.layers.Layer):
         return_token_type_ids: bool = True,
         return_attention_mask: bool = True,
         use_fast_bert_tokenizer: bool = True,
+        **tokenizer_kwargs,
     ):
         super().__init__()
         if use_fast_bert_tokenizer:
             self.tf_tokenizer = FastBertTokenizer(
-                vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case
+                vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
             )
         else:
             lookup_table = tf.lookup.StaticVocabularyTable(
@@ -81,7 +84,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
                 ),
                 num_oov_buckets=1,
             )
-            self.tf_tokenizer = BertTokenizerLayer(lookup_table, token_out_type=tf.int64, lower_case=do_lower_case)
+            self.tf_tokenizer = BertTokenizerLayer(
+                lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
+            )
         self.vocab_list = vocab_list
         self.do_lower_case = do_lower_case
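
Usage sketch (not part of the diff): with this change, extra keyword arguments given to TFBertTokenizer are forwarded verbatim to the underlying TF Text tokenizer, so options that only the non-fast BertTokenizer path understands can be enabled. The example below assumes that TFBertTokenizer.from_pretrained also forwards unrecognized keyword arguments to __init__, and preserve_unused_token is just one example of a tensorflow_text.BertTokenizer option.

import tensorflow as tf
from transformers import TFBertTokenizer

# Sketch: extra kwargs are passed straight through to the underlying TF Text
# tokenizer. `preserve_unused_token` is a tensorflow_text.BertTokenizer option,
# so the slower non-fast path is selected here. Whether from_pretrained forwards
# these kwargs to __init__ is an assumption of this example.
tokenizer = TFBertTokenizer.from_pretrained(
    "bert-base-uncased",
    use_fast_bert_tokenizer=False,
    preserve_unused_token=True,  # forwarded via **tokenizer_kwargs
)

outputs = tokenizer(tf.constant(["Hello [unused1] world!"]))
print(outputs["input_ids"])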