mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Allow passing kwargs through to TFBertTokenizer (#24324)
This commit is contained in:
parent
cfc838dd4d
commit
0875b2509a
@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
|
||||
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the attention_mask.
|
||||
use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
|
||||
If set to false will use standard TF Text BertTokenizer, making it servable by TF Serving.
|
||||
If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
|
||||
class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
|
||||
TFLite.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -65,11 +67,12 @@ class TFBertTokenizer(tf.keras.layers.Layer):
|
||||
return_token_type_ids: bool = True,
|
||||
return_attention_mask: bool = True,
|
||||
use_fast_bert_tokenizer: bool = True,
|
||||
**tokenizer_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_fast_bert_tokenizer:
|
||||
self.tf_tokenizer = FastBertTokenizer(
|
||||
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case
|
||||
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
|
||||
)
|
||||
else:
|
||||
lookup_table = tf.lookup.StaticVocabularyTable(
|
||||
@ -81,7 +84,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
|
||||
),
|
||||
num_oov_buckets=1,
|
||||
)
|
||||
self.tf_tokenizer = BertTokenizerLayer(lookup_table, token_out_type=tf.int64, lower_case=do_lower_case)
|
||||
self.tf_tokenizer = BertTokenizerLayer(
|
||||
lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
|
||||
)
|
||||
|
||||
self.vocab_list = vocab_list
|
||||
self.do_lower_case = do_lower_case
|
||||
|
Loading…
Reference in New Issue
Block a user