From 0875b2509aaef6eafdc9b49856669e04fb4aa9e6 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 20 Jun 2023 12:49:06 +0100 Subject: [PATCH] Allow passing kwargs through to TFBertTokenizer (#24324) --- src/transformers/models/bert/tokenization_bert_tf.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bert/tokenization_bert_tf.py b/src/transformers/models/bert/tokenization_bert_tf.py index e0e38d68a58..281d222fbda 100644 --- a/src/transformers/models/bert/tokenization_bert_tf.py +++ b/src/transformers/models/bert/tokenization_bert_tf.py @@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer): return_attention_mask (`bool`, *optional*, defaults to `True`): Whether to return the attention_mask. use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`): - If set to false will use standard TF Text BertTokenizer, making it servable by TF Serving. + If True, will use the FastBertTokenizer class from TensorFlow Text. If False, will use the BertTokenizer + class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to + TFLite. 
""" def __init__( @@ -65,11 +67,12 @@ class TFBertTokenizer(tf.keras.layers.Layer): return_token_type_ids: bool = True, return_attention_mask: bool = True, use_fast_bert_tokenizer: bool = True, + **tokenizer_kwargs, ): super().__init__() if use_fast_bert_tokenizer: self.tf_tokenizer = FastBertTokenizer( - vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case + vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs ) else: lookup_table = tf.lookup.StaticVocabularyTable( @@ -81,7 +84,9 @@ class TFBertTokenizer(tf.keras.layers.Layer): ), num_oov_buckets=1, ) - self.tf_tokenizer = BertTokenizerLayer(lookup_table, token_out_type=tf.int64, lower_case=do_lower_case) + self.tf_tokenizer = BertTokenizerLayer( + lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs + ) self.vocab_list = vocab_list self.do_lower_case = do_lower_case