add in layer gpt2 tokenizer (#20421)
* add minimal working gpt2 tokenizer
* graph mode and output equivalence tests working
* not today tensorflow. serialization test passing!
* fix style, documentation, docstrings and all that jazz
* passing consistency checks
* move keras nlp to tf dependencies
* fix tf modeling utils and gpt2 attention to enable compiling
* fix (I hope) keras nlp dependencies
* revert changes on generation
* remove debug prints
* remove redundant tf dummy objects
* add from config, get config and max length settings to address review
* let flake ignore the error on distillation you are welcome
* test from config
* add padding test
* address sgugger review
parent e8d448edcf
commit fb2b45e562
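What the change enables, as a minimal sketch (not code from this commit; it loosely mirrors the serialization test added below, and the "gpt2" checkpoint plus the `> 0` attention-mask heuristic are illustrative assumptions):

```python
import tensorflow as tf

from transformers import TFGPT2LMHeadModel, TFGPT2Tokenizer

# The checkpoint name and mask heuristic below are illustrative, not mandated by the commit.
tokenizer = TFGPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")


@tf.function
def predict(text):
    # Tokenization happens inside the compiled graph, straight from tf.string inputs.
    tokenized = tokenizer(text)
    input_ids = tokenized["input_ids"].to_tensor()
    attention_mask = tf.cast(input_ids > 0, tf.int32)
    return model(input_ids=input_ids, attention_mask=attention_mask)["logits"]


logits = predict(tf.constant(["Hello there, GPT-2"]))
print(logits.shape)  # (1, sequence_length, vocab_size)
```

Because the tokenizer is a Keras layer, the whole `predict` function can be traced with `tf.function` and exported with `tf.saved_model.save`, which is what the new serialization test exercises.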
docs/source/en/model_doc/gpt2.mdx
@@ -138,6 +138,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 [[autodoc]] modeling_tf_outputs.TFSequenceClassifierOutputWithPast

+## TFGPT2Tokenizer
+
+[[autodoc]] TFGPT2Tokenizer
+
 ## FlaxGPT2Model

 [[autodoc]] FlaxGPT2Model

@@ -5,7 +5,7 @@ import gc
 import os
 import sys
 from pathlib import Path
-from typing import List
+from typing import List  # noqa: F401

 import pytorch_lightning as pl
 import torch
setup.py (6 changes)
@@ -124,6 +124,7 @@ _deps = [
     "jaxlib>=0.1.65,<=0.3.6",
     "jieba",
     "kenlm",
+    "keras-nlp>=0.3.1",
     "nltk",
     "natten>=0.14.4",
     "numpy>=1.17",
@@ -241,14 +242,13 @@ class DepsTableUpdateCommand(Command):
         with open(target, "w", encoding="utf-8", newline="\n") as f:
             f.write("\n".join(content))


 extras = {}

 extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "pyknp")
 extras["sklearn"] = deps_list("scikit-learn")

-extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text")
-extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text")
+extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
+extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")

 extras["torch"] = deps_list("torch")
 extras["accelerate"] = deps_list("accelerate")
src/transformers/__init__.py
@@ -32,6 +32,7 @@ from .utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_flax_available,
+    is_keras_nlp_available,
     is_sentencepiece_available,
     is_speech_available,
     is_tensorflow_text_available,
@@ -694,6 +695,19 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["models.bert"].append("TFBertTokenizer")

+# keras-nlp-specific objects
+try:
+    if not is_keras_nlp_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_keras_nlp_objects
+
+    _import_structure["utils.dummy_keras_nlp_objects"] = [
+        name for name in dir(dummy_keras_nlp_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["models.gpt2"].append("TFGPT2Tokenizer")
+
 try:
     if not (is_sentencepiece_available() and is_speech_available()):
         raise OptionalDependencyNotAvailable()
@@ -3828,6 +3842,14 @@ if TYPE_CHECKING:
     else:
         from .models.bert import TFBertTokenizer

+    try:
+        if not is_keras_nlp_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_keras_nlp_objects import *
+    else:
+        from .models.gpt2 import TFGPT2Tokenizer
+
     try:
         if not (is_speech_available() and is_sentencepiece_available()):
             raise OptionalDependencyNotAvailable()
src/transformers/dependency_versions_table.py
@@ -30,6 +30,7 @@ deps = {
     "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
     "jieba": "jieba",
     "kenlm": "kenlm",
+    "keras-nlp": "keras-nlp>=0.3.1",
     "nltk": "nltk",
     "natten": "natten>=0.14.4",
     "numpy": "numpy>=1.17",
src/transformers/models/gpt2/__init__.py
@@ -22,6 +22,7 @@ from ...utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_flax_available,
+    is_keras_nlp_available,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
@@ -74,6 +75,14 @@ else:
         "TFGPT2PreTrainedModel",
     ]

+try:
+    if not is_keras_nlp_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_gpt2_tf"] = ["TFGPT2Tokenizer"]
+
 try:
     if not is_flax_available():
         raise OptionalDependencyNotAvailable()
@@ -127,6 +136,14 @@ if TYPE_CHECKING:
             TFGPT2PreTrainedModel,
         )

+    try:
+        if not is_keras_nlp_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_gpt2_tf import TFGPT2Tokenizer
+
     try:
         if not is_flax_available():
             raise OptionalDependencyNotAvailable()
src/transformers/models/gpt2/tokenization_gpt2_tf.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+import os
+from typing import Dict, List, Union
+
+import tensorflow as tf
+
+from keras_nlp.tokenizers import BytePairTokenizer
+from tensorflow_text import pad_model_inputs
+
+from .tokenization_gpt2 import GPT2Tokenizer
+
+
+class TFGPT2Tokenizer(tf.keras.layers.Layer):
+    """
+    This is an in-graph tokenizer for GPT2. It should be initialized similarly to other tokenizers, using the
+    `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
+    from an existing standard tokenizer object.
+
+    In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
+    when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
+    than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
+    straight from `tf.string` inputs to outputs.
+
+    Args:
+        vocab (Dict[str, int]): Vocabulary dict for Byte Pair Tokenizer
+        merges (List[str]): Merges list for Byte Pair Tokenizer
+    """
+
+    def __init__(self, vocab: Dict[str, int], merges: List[str], max_length: int = None, pad_token_id: int = None):
+        super().__init__()
+        self.pad_token_id = pad_token_id
+        self.max_length = max_length
+        self.vocab = vocab
+        self.merges = merges
+        self.tf_tokenizer = BytePairTokenizer(vocab, merges, sequence_length=max_length)
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: GPT2Tokenizer, *args, **kwargs):
+        """Creates TFGPT2Tokenizer from GPT2Tokenizer
+
+        Args:
+            tokenizer (GPT2Tokenizer)
+
+        Examples:
+
+        ```python
+        from transformers import AutoTokenizer, TFGPT2Tokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(tokenizer)
+        ```
+        """
+        merges = [" ".join(m) for m in tokenizer.bpe_ranks.keys()]
+        vocab = tokenizer.get_vocab()
+        return cls(vocab, merges, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
+        """Creates TFGPT2Tokenizer from pretrained GPT2Tokenizer
+
+        Args:
+            pretrained_model_name_or_path (Union[str, os.PathLike]): Path to pretrained model
+
+        Examples:
+
+        ```python
+        from transformers import TFGPT2Tokenizer
+
+        tf_tokenizer = TFGPT2Tokenizer.from_pretrained("gpt2")
+        ```
+        """
+        tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
+        return cls.from_tokenizer(tokenizer, *init_inputs, **kwargs)
+
+    @classmethod
+    def from_config(cls, config):
+        """Creates TFGPT2Tokenizer from configurations
+
+        Args:
+            config (Dict): Dictionary with keys such as stated in `get_config`.
+        """
+        return cls(**config)
+
+    def get_config(self):
+        return {
+            "vocab": self.vocab,
+            "merges": self.merges,
+            "max_length": self.max_length,
+            "pad_token_id": self.pad_token_id,
+        }
+
+    def call(self, x, max_length: int = None):
+        input_ids = self.tf_tokenizer(x)
+        attention_mask = tf.ones_like(input_ids)
+
+        if self.pad_token_id is not None:
+            # pad the tokens up to max length
+            max_length = max_length if max_length is not None else self.max_length
+
+            if max_length is not None:
+                input_ids, attention_mask = pad_model_inputs(
+                    input_ids, max_seq_length=max_length, pad_value=self.pad_token_id
+                )
+
+        return {"attention_mask": attention_mask, "input_ids": input_ids}
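A hedged usage sketch for the padding path above (not part of the diff): `call()` only pads when `pad_token_id` is set and a `max_length` is available, either from the constructor or per call. Reusing the EOS token as the pad value and the `max_length=16` choice are illustrative assumptions.

```python
import tensorflow as tf

from transformers import GPT2Tokenizer, TFGPT2Tokenizer

base = GPT2Tokenizer.from_pretrained("gpt2")
# Assumption for the sketch: reuse the EOS token as padding, a common GPT-2 convention.
tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(base, pad_token_id=base.eos_token_id)

batch = tf.constant(["Hello world", "A somewhat longer second sentence for the batch"])
outputs = tf_tokenizer(batch, max_length=16)  # per-call max_length triggers pad_model_inputs

print(outputs["input_ids"].shape)       # (2, 16)
print(outputs["attention_mask"].shape)  # (2, 16)
```

Without a pad token the layer returns ragged ids and a mask of ones, which is the path the output-equivalence test below relies on.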
src/transformers/utils/__init__.py
@@ -111,6 +111,7 @@ from .import_utils import (
     is_ipex_available,
     is_jumanpp_available,
     is_kenlm_available,
+    is_keras_nlp_available,
     is_librosa_available,
     is_more_itertools_available,
     is_natten_available,
src/transformers/utils/dummy_keras_nlp_objects.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class TFGPT2Tokenizer(metaclass=DummyObject):
+    _backends = ["keras_nlp"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["keras_nlp"])
src/transformers/utils/import_utils.py
@@ -572,6 +572,10 @@ def is_tensorflow_text_available():
     return importlib.util.find_spec("tensorflow_text") is not None


+def is_keras_nlp_available():
+    return importlib.util.find_spec("keras_nlp") is not None
+
+
 def is_in_notebook():
     try:
         # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
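A short sketch (not from the diff) of how the new availability check gates the optional import, mirroring the lazy-import blocks earlier in this commit:

```python
from transformers.utils import is_keras_nlp_available

if is_keras_nlp_available():
    # Only usable when the optional keras_nlp backend is installed; otherwise the
    # lazy import machinery serves the dummy object defined above instead.
    from transformers import TFGPT2Tokenizer

    tf_tokenizer = TFGPT2Tokenizer.from_pretrained("gpt2")
else:
    print("keras-nlp is not installed; install it (or the `transformers[tf]` extra) to use TFGPT2Tokenizer.")
```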
tests/models/gpt2/test_tokenization_gpt2_tf.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from transformers import AutoConfig, TFGPT2LMHeadModel, is_tensorflow_text_available, is_tf_available
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+from transformers.testing_utils import require_tensorflow_text, slow
+
+
+if is_tensorflow_text_available():
+    from transformers.models.gpt2 import TFGPT2Tokenizer
+
+if is_tf_available():
+    import tensorflow as tf
+
+
+TOKENIZER_CHECKPOINTS = ["gpt2"]
+TINY_MODEL_CHECKPOINT = "gpt2"
+
+if is_tf_available():
+
+    class ModelToSave(tf.Module):
+        def __init__(self, tokenizer):
+            super().__init__()
+            self.tokenizer = tokenizer
+            config = AutoConfig.from_pretrained(TINY_MODEL_CHECKPOINT)
+            self.model = TFGPT2LMHeadModel.from_config(config)
+
+        @tf.function(input_signature=(tf.TensorSpec((None,), tf.string, name="text"),))
+        def serving(self, text):
+
+            tokenized = self.tokenizer(text)
+            input_ids_dense = tokenized["input_ids"].to_tensor()
+
+            input_mask = tf.cast(input_ids_dense > 0, tf.int32)
+            # input_mask = tf.reshape(input_mask, [-1, MAX_SEQ_LEN])
+
+            outputs = self.model(input_ids=input_ids_dense, attention_mask=input_mask)["logits"]
+
+            return outputs
+
+
+@require_tensorflow_text
+class GPTTokenizationTest(unittest.TestCase):
+    # The TF tokenizers are usually going to be used as pretrained tokenizers from existing model checkpoints,
+    # so that's what we focus on here.
+
+    def setUp(self):
+        super().setUp()
+
+        self.tokenizers = [GPT2Tokenizer.from_pretrained(checkpoint) for checkpoint in (TOKENIZER_CHECKPOINTS)]
+        self.tf_tokenizers = [TFGPT2Tokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
+        assert len(self.tokenizers) == len(self.tf_tokenizers)
+
+        self.test_sentences = [
+            "This is a straightforward English test sentence.",
+            "This one has some weird characters\rto\nsee\r\nif those\u00E9break things.",
+            "Now we're going to add some Chinese: 一 二 三 一二三",
+            "And some much more rare Chinese: 齉 堃 齉堃",
+            "Je vais aussi écrire en français pour tester les accents",
+            "Classical Irish also has some unusual characters, so in they go: Gaelaċ, ꝼ",
+        ]
+        self.paired_sentences = list(zip(self.test_sentences, self.test_sentences[::-1]))
+
+    def test_output_equivalence(self):
+        for tokenizer, tf_tokenizer in zip(self.tokenizers, self.tf_tokenizers):
+            for test_inputs in self.test_sentences:
+                python_outputs = tokenizer([test_inputs], return_tensors="tf")
+                tf_outputs = tf_tokenizer([test_inputs])
+
+                for key in python_outputs.keys():
+                    # convert them to numpy to avoid messing with ragged tensors
+                    python_outputs_values = python_outputs[key].numpy()
+                    tf_outputs_values = tf_outputs[key].numpy()
+
+                    self.assertTrue(tf.reduce_all(python_outputs_values.shape == tf_outputs_values.shape))
+                    self.assertTrue(tf.reduce_all(tf.cast(python_outputs_values, tf.int64) == tf_outputs_values))
+
+    @slow
+    def test_graph_mode(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            compiled_tokenizer = tf.function(tf_tokenizer)
+            for test_inputs in self.test_sentences:
+                test_inputs = tf.constant(test_inputs)
+                compiled_outputs = compiled_tokenizer(test_inputs)
+                eager_outputs = tf_tokenizer(test_inputs)
+
+                for key in eager_outputs.keys():
+                    self.assertTrue(tf.reduce_all(eager_outputs[key] == compiled_outputs[key]))
+
+    @slow
+    def test_saved_model(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            model = ModelToSave(tokenizer=tf_tokenizer)
+            test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
+            out = model.serving(test_inputs)  # Build model with some sample inputs
+            with TemporaryDirectory() as tempdir:
+                save_path = Path(tempdir) / "saved.model"
+                tf.saved_model.save(model, save_path, signatures={"serving_default": model.serving})
+                loaded_model = tf.saved_model.load(save_path)
+                loaded_output = loaded_model.signatures["serving_default"](test_inputs)["output_0"]
+
+            # We may see small differences because the loaded model is compiled, so we need an epsilon for the test
+            self.assertTrue(tf.reduce_all(out == loaded_output))
+
+    @slow
+    def test_from_config(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
+            out = tf_tokenizer(test_inputs)  # Build model with some sample inputs
+
+            config = tf_tokenizer.get_config()
+            model_from_config = TFGPT2Tokenizer.from_config(config)
+            from_config_output = model_from_config(test_inputs)
+
+            for key in from_config_output.keys():
+                self.assertTrue(tf.reduce_all(from_config_output[key] == out[key]))
+
+    @slow
+    def test_padding(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            # for the test to run
+            tf_tokenizer.pad_token_id = 123123
+
+            for max_length in [3, 5, 1024]:
+                test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
+                out = tf_tokenizer(test_inputs, max_length=max_length)
+
+                out_length = out["input_ids"].numpy().shape[1]
+
+                assert out_length == max_length