Adds pretrained IDs directly in the tests (#29534)

* Adds pretrained IDs directly in the tests

* Fix tests

* Fix tests

* Review!
Lysandre Debut 2024-03-13 14:53:27 +01:00 committed by GitHub
parent 38bff8c84f
commit 11bbb505c7
89 changed files with 95 additions and 8 deletions
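
Every file in this diff follows the same pattern: the tokenizer test class gains a from_pretrained_id class attribute naming the Hub checkpoint it tests against, which the shared TokenizerTesterMixin (last file below) turns into self.tokenizers_list in setUp. A minimal sketch of what one updated test class looks like, using the Albert values from the first hunk; this is illustrative only — the real class also inherits TokenizerTesterMixin from tests/test_tokenization_common.py and defines many more attributes and tests.

import unittest

from transformers import AlbertTokenizer, AlbertTokenizerFast
from transformers.testing_utils import require_sentencepiece, require_tokenizers


# In the real suite this class also mixes in TokenizerTesterMixin, whose setUp
# reads from_pretrained_id (see the last hunk in this diff).
@require_sentencepiece
@require_tokenizers
class AlbertTokenizationTest(unittest.TestCase):
    from_pretrained_id = "albert/albert-base-v1"  # checkpoint pinned directly in the test
    tokenizer_class = AlbertTokenizer
    rust_tokenizer_class = AlbertTokenizerFast
    test_rust_tokenizer = True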

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "albert/albert-base-v1"
     tokenizer_class = AlbertTokenizer
     rust_tokenizer_class = AlbertTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_det
 @require_tokenizers
 class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/bart-base"
     tokenizer_class = BartTokenizer
     rust_tokenizer_class = BartTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sentencepiece
 @slow # see https://github.com/huggingface/transformers/issues/11457
 class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "moussaKam/mbarthez"
     tokenizer_class = BarthezTokenizer
     rust_tokenizer_class = BarthezTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe.model")
 class BartphoTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "vinai/bartpho-syllable"
     tokenizer_class = BartphoTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google-bert/bert-base-uncased"
     tokenizer_class = BertTokenizer
     rust_tokenizer_class = BertTokenizerFast
     test_rust_tokenizer = True

@@ -29,6 +29,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 class BertGenerationTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/bert_for_seq_generation_L-24_bbc_encoder"
     tokenizer_class = BertGenerationTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -36,6 +36,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @custom_tokenizers
 class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "cl-tohoku/bert-base-japanese"
     tokenizer_class = BertJapaneseTokenizer
     test_rust_tokenizer = False
     space_between_special_tokens = True
@@ -403,6 +404,7 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 @custom_tokenizers
 class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "cl-tohoku/bert-base-japanese"
     tokenizer_class = BertJapaneseTokenizer
     test_rust_tokenizer = False

@@ -22,6 +22,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class BertweetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "vinai/bertweet-base"
     tokenizer_class = BertweetTokenizer
     test_rust_tokenizer = False

@@ -30,6 +30,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class BigBirdTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/bigbird-roberta-base"
     tokenizer_class = BigBirdTokenizer
     rust_tokenizer_class = BigBirdTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sacremoses
 class BioGptTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/biogpt"
     tokenizer_class = BioGptTokenizer
     test_rust_tokenizer = False

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class BlenderbotSmallTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/blenderbot_small-90M"
     tokenizer_class = BlenderbotSmallTokenizer
     test_rust_tokenizer = False

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "bigscience/tokenizer"
     slow_tokenizer_class = None
     rust_tokenizer_class = BloomTokenizerFast
     tokenizer_class = BloomTokenizerFast

@@ -32,6 +32,7 @@ FRAMEWORK = "pt" if is_torch_available() else "tf"
 @require_sentencepiece
 @require_tokenizers
 class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "almanach/camembert-base"
     tokenizer_class = CamembertTokenizer
     rust_tokenizer_class = CamembertTokenizerFast
     test_rust_tokenizer = True

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "nielsr/canine-s"
     tokenizer_class = CanineTokenizer
     test_rust_tokenizer = False

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class CLIPTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai/clip-vit-base-patch32"
     tokenizer_class = CLIPTokenizer
     rust_tokenizer_class = CLIPTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, slow
 class ClvpTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "susnato/clvp_dev"
     tokenizer_class = ClvpTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"add_prefix_space": True}

@@ -51,6 +51,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class CodeLlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "hf-internal-testing/llama-code-tokenizer"
     tokenizer_class = CodeLlamaTokenizer
     rust_tokenizer_class = CodeLlamaTokenizerFast
     test_rust_tokenizer = False

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class CodeGenTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Salesforce/codegen-350M-mono"
     tokenizer_class = CodeGenTokenizer
     rust_tokenizer_class = CodeGenTokenizerFast
     test_rust_tokenizer = True

@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_jieba
 class CPMAntTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openbmb/cpm-ant-10b"
     tokenizer_class = CpmAntTokenizer
     test_rust_tokenizer = False

@@ -23,6 +23,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Salesforce/ctrl"
     tokenizer_class = CTRLTokenizer
     test_rust_tokenizer = False
     test_seq2seq = False

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-base"
     tokenizer_class = DebertaTokenizer
     test_rust_tokenizer = True
     rust_tokenizer_class = DebertaTokenizerFast

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-v2-xlarge"
     tokenizer_class = DebertaV2Tokenizer
     rust_tokenizer_class = DebertaV2TokenizerFast
     test_sentencepiece = True

@@ -25,6 +25,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     tokenizer_class = DistilBertTokenizer
     rust_tokenizer_class = DistilBertTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "distilbert/distilbert-base-uncased"
     @slow
     def test_sequence_builders(self):

@@ -33,6 +33,7 @@ class DPRContextEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRContextEncoderTokenizer
     rust_tokenizer_class = DPRContextEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
 @require_tokenizers
@@ -40,6 +41,7 @@ class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRQuestionEncoderTokenizer
     rust_tokenizer_class = DPRQuestionEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
 @require_tokenizers
@@ -47,6 +49,7 @@ class DPRReaderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRReaderTokenizer
     rust_tokenizer_class = DPRReaderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
     @slow
     def test_decode_best_spans(self):

@@ -33,6 +33,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class ElectraTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/electra-small-generator"
     tokenizer_class = ElectraTokenizer
     rust_tokenizer_class = ElectraTokenizerFast
     test_rust_tokenizer = True

@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class ErnieMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "susnato/ernie-m-base_pytorch"
     tokenizer_class = ErnieMTokenizer
     test_seq2seq = False
     test_sentencepiece = True

@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_g2p_en
 class FastSpeech2ConformerTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "espnet/fastspeech2_conformer"
     tokenizer_class = FastSpeech2ConformerTokenizer
     test_rust_tokenizer = False

@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class FNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/fnet-base"
     tokenizer_class = FNetTokenizer
     rust_tokenizer_class = FNetTokenizerFast
     test_rust_tokenizer = True

@@ -30,6 +30,7 @@ FSMT_TINY2 = "stas/tiny-wmt19-en-ru"
 class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "stas/tiny-wmt19-en-de"
     tokenizer_class = FSMTTokenizer
     test_rust_tokenizer = False

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class FunnelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "funnel-transformer/small"
     tokenizer_class = FunnelTokenizer
     rust_tokenizer_class = FunnelTokenizerFast
     test_rust_tokenizer = True

@@ -49,6 +49,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/gemma-7b"
     tokenizer_class = GemmaTokenizer
     rust_tokenizer_class = GemmaTokenizerFast

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/gpt2"
     tokenizer_class = GPT2Tokenizer
     rust_tokenizer_class = GPT2TokenizerFast
     test_rust_tokenizer = True

@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPTNeoXJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "abeja/gpt-neox-japanese-2.7b"
     tokenizer_class = GPTNeoXJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.mode
 @require_sentencepiece
 @require_tokenizers
 class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "AI-Sweden-Models/gpt-sw3-126m"
     tokenizer_class = GPTSw3Tokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Tanrei/GPTSAN-japanese"
     tokenizer_class = GPTSanJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sacremoses
 @require_tokenizers
 class HerbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allegro/herbert-base-cased"
     tokenizer_class = HerbertTokenizer
     rust_tokenizer_class = HerbertTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class LayoutLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlm-base-uncased"
     tokenizer_class = LayoutLMTokenizer
     rust_tokenizer_class = LayoutLMTokenizerFast
     test_rust_tokenizer = True

@@ -61,6 +61,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv2-base-uncased"
     tokenizer_class = LayoutLMv2Tokenizer
     rust_tokenizer_class = LayoutLMv2TokenizerFast
     test_rust_tokenizer = True

@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv3-base"
     tokenizer_class = LayoutLMv3Tokenizer
     rust_tokenizer_class = LayoutLMv3TokenizerFast
     test_rust_tokenizer = True

@@ -54,6 +54,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_tokenizers
 @require_pandas
 class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-roberta-base"
     tokenizer_class = LayoutXLMTokenizer
     rust_tokenizer_class = LayoutXLMTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class TestTokenizationLED(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/led-base-16384"
     tokenizer_class = LEDTokenizer
     rust_tokenizer_class = LEDTokenizerFast
     test_rust_tokenizer = True

@@ -52,6 +52,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "hf-internal-testing/llama-tokenizer"
     tokenizer_class = LlamaTokenizer
     rust_tokenizer_class = LlamaTokenizerFast

@@ -30,6 +30,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with FacebookAI/roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
 class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/longformer-base-4096"
     # Ignore copy
     tokenizer_class = LongformerTokenizer
     test_slow_tokenizer = True

@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
 class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/luke-base"
     tokenizer_class = LukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "unc-nlp/lxmert-base-uncased"
     tokenizer_class = LxmertTokenizer
     rust_tokenizer_class = LxmertTokenizerFast
     test_rust_tokenizer = True

@@ -48,6 +48,7 @@ FR_CODE = 128028
 @require_sentencepiece
 class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/m2m100_418M"
     tokenizer_class = M2M100Tokenizer
     test_rust_tokenizer = False
     test_seq2seq = False

@@ -45,6 +45,7 @@ else:
 @require_sentencepiece
 class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Helsinki-NLP/opus-mt-en-de"
     tokenizer_class = MarianTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -41,6 +41,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/markuplm-base"
     tokenizer_class = MarkupLMTokenizer
     rust_tokenizer_class = MarkupLMTokenizerFast
     test_rust_tokenizer = True

@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-en-ro"
     tokenizer_class = MBartTokenizer
     rust_tokenizer_class = MBartTokenizerFast
     test_rust_tokenizer = True

@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-50-one-to-many-mmt"
     tokenizer_class = MBart50Tokenizer
     rust_tokenizer_class = MBart50TokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class MgpstrTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "alibaba-damo/mgp-str-base"
     tokenizer_class = MgpstrTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {}

@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
 class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/mluke-base"
     tokenizer_class = MLukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}

@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "mobilebert-uncased"
     tokenizer_class = MobileBertTokenizer
     rust_tokenizer_class = MobileBertTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/mpnet-base"
     tokenizer_class = MPNetTokenizer
     rust_tokenizer_class = MPNetTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_det
 @require_tokenizers
 class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "RUCAIBox/mvp"
     tokenizer_class = MvpTokenizer
     rust_tokenizer_class = MvpTokenizerFast
     test_rust_tokenizer = True

@@ -49,6 +49,7 @@ RO_CODE = 256145
 @require_sentencepiece
 @require_tokenizers
 class NllbTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nllb-200-distilled-600M"
     tokenizer_class = NllbTokenizer
     rust_tokenizer_class = NllbTokenizerFast
     test_rust_tokenizer = True

@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class NougatTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nougat-base"
     slow_tokenizer_class = None
     rust_tokenizer_class = NougatTokenizerFast
     tokenizer_class = NougatTokenizerFast

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/openai-gpt"
     """Tests OpenAIGPTTokenizer that uses BERT BasicTokenizer."""
     tokenizer_class = OpenAIGPTTokenizer

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
 @require_sentencepiece
 @require_tokenizers
 class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True
@@ -135,6 +136,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True

@@ -36,6 +36,7 @@ else:
 class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "deepmind/language-perceiver"
     tokenizer_class = PerceiverTokenizer
     test_rust_tokenizer = False

@@ -22,6 +22,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class PhobertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "vinai/phobert-base"
     tokenizer_class = PhobertTokenizer
     test_rust_tokenizer = False

@@ -40,6 +40,7 @@ PYTHON_CODE = 50002
 @require_sentencepiece
 @require_tokenizers
 class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "uclanlp/plbart-base"
     tokenizer_class = PLBartTokenizer
     rust_tokenizer_class = None
     test_rust_tokenizer = False

@@ -32,6 +32,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/prophetnet-large-uncased"
     tokenizer_class = ProphetNetTokenizer
     test_rust_tokenizer = False

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class Qwen2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "qwen/qwen-tokenizer"
     tokenizer_class = Qwen2Tokenizer
     rust_tokenizer_class = Qwen2TokenizerFast
     test_slow_tokenizer = True

@@ -33,6 +33,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/realm-cc-news-pretrained-embedder"
     tokenizer_class = RealmTokenizer
     rust_tokenizer_class = RealmTokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/reformer-crime-and-punishment"
     tokenizer_class = ReformerTokenizer
     rust_tokenizer_class = ReformerTokenizerFast
     test_rust_tokenizer = True

@@ -32,6 +32,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class RemBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/rembert"
     tokenizer_class = RemBertTokenizer
     rust_tokenizer_class = RemBertTokenizerFast
     space_between_special_tokens = True

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/roberta-base"
     tokenizer_class = RobertaTokenizer
     rust_tokenizer_class = RobertaTokenizerFast
     test_rust_tokenizer = True

@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "weiweishi/roc-bert-base-zh"
     tokenizer_class = RoCBertTokenizer
     rust_tokenizer_class = None
     test_rust_tokenizer = False

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_rjieba
 @require_tokenizers
 class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "junnyu/roformer_chinese_small"
     tokenizer_class = RoFormerTokenizer
     rust_tokenizer_class = RoFormerTokenizerFast
     space_between_special_tokens = True

@@ -53,6 +53,7 @@ SMALL_TRAINING_CORPUS = [
 @require_sentencepiece
 @require_tokenizers
 class SeamlessM4TTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/hf-seamless-m4t-medium"
     tokenizer_class = SeamlessM4TTokenizer
     rust_tokenizer_class = SeamlessM4TTokenizerFast
     test_rust_tokenizer = True

@@ -38,6 +38,7 @@ else:
 @require_sentencepiece
 @require_tokenizers
 class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/siglip-base-patch16-224"
     tokenizer_class = SiglipTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -37,6 +37,7 @@ ES_CODE = 10
 @require_sentencepiece
 @require_tokenizers
 class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/s2t-small-librispeech-asr"
     tokenizer_class = Speech2TextTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/s2t-wav2vec2-large-en-de"
     tokenizer_class = Speech2Text2Tokenizer
     test_rust_tokenizer = False

@@ -30,6 +30,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
 @require_sentencepiece
 @require_tokenizers
 class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/speecht5_asr"
     tokenizer_class = SpeechT5Tokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -25,6 +25,7 @@ class SqueezeBertTokenizationTest(BertTokenizationTest):
     tokenizer_class = SqueezeBertTokenizer
     rust_tokenizer_class = SqueezeBertTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "squeezebert/squeezebert-uncased"
     def get_rust_tokenizer(self, **kwargs):
         return SqueezeBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

@@ -38,6 +38,7 @@ else:
 @require_sentencepiece
 @require_tokenizers
 class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google-t5/t5-small"
     tokenizer_class = T5Tokenizer
     rust_tokenizer_class = T5TokenizerFast
     test_rust_tokenizer = True

@@ -53,6 +53,7 @@ else:
 @require_tokenizers
 @require_pandas
 class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/tapas-large-finetuned-sqa"
     tokenizer_class = TapasTokenizer
     test_rust_tokenizer = False
     space_between_special_tokens = True

@@ -54,6 +54,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_tokenizers
 @require_pandas
 class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/udop-large"
     tokenizer_class = UdopTokenizer
     rust_tokenizer_class = UdopTokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class VitsTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mms-tts-eng"
     tokenizer_class = VitsTokenizer
     test_rust_tokenizer = False

@@ -367,6 +367,7 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
 class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/wav2vec2-base-960h"
     tokenizer_class = Wav2Vec2CTCTokenizer
     test_rust_tokenizer = False

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_phonemizer
 class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/wav2vec2-lv-60-espeak-cv-ft"
     tokenizer_class = Wav2Vec2PhonemeCTCTokenizer
     test_rust_tokenizer = False

@@ -31,6 +31,7 @@ NOTIMESTAMPS = 50363
 class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai/whisper-tiny"
     tokenizer_class = WhisperTokenizer
     rust_tokenizer_class = WhisperTokenizerFast
     test_rust_tokenizer = True

@@ -31,6 +31,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class XGLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/xglm-564M"
     tokenizer_class = XGLMTokenizer
     rust_tokenizer_class = XGLMTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-mlm-en-2048"
     tokenizer_class = XLMTokenizer
     test_rust_tokenizer = False

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 class XLMProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/xprophetnet-large-wiki100-cased"
     tokenizer_class = XLMProphetNetTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -31,6 +31,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-roberta-base"
     tokenizer_class = XLMRobertaTokenizer
     rust_tokenizer_class = XLMRobertaTokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "xlnet/xlnet-base-cased"
     tokenizer_class = XLNetTokenizer
     rust_tokenizer_class = XLNetTokenizerFast
     test_rust_tokenizer = True

@@ -186,6 +186,7 @@ class TokenizerTesterMixin:
     space_between_special_tokens = False
     from_pretrained_kwargs = None
     from_pretrained_filter = None
+    from_pretrained_id = None
     from_pretrained_vocab_key = "vocab_file"
     test_seq2seq = True
@@ -200,19 +201,13 @@ class TokenizerTesterMixin:
         # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
         # information available in Tokenizer (name, rust class, python class, vocab key name)
         if self.test_rust_tokenizer:
-            tokenizers_list = [
+            self.tokenizers_list = [
                 (
                     self.rust_tokenizer_class,
-                    pretrained_name,
+                    self.from_pretrained_id,
                     self.from_pretrained_kwargs if self.from_pretrained_kwargs is not None else {},
                 )
-                for pretrained_name in self.rust_tokenizer_class.pretrained_vocab_files_map[
-                    self.from_pretrained_vocab_key
-                ].keys()
-                if self.from_pretrained_filter is None
-                or (self.from_pretrained_filter is not None and self.from_pretrained_filter(pretrained_name))
             ]
-            self.tokenizers_list = tokenizers_list[:1] # Let's just test the first pretrained vocab for speed
         else:
             self.tokenizers_list = []
         with open(f"{get_tests_dir()}/fixtures/sample_text.txt", encoding="utf-8") as f_data:
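
With this change, setUp no longer enumerates rust_tokenizer_class.pretrained_vocab_files_map; it builds a single-entry self.tokenizers_list of (rust_tokenizer_class, from_pretrained_id, from_pretrained_kwargs) tuples taken straight from the class attributes. A hedged sketch of how a test in the mixin typically consumes that list; the class, test name, and assertions below are illustrative, not copied from the suite, and the sketch is meant to be mixed into a unittest.TestCase the same way TokenizerTesterMixin is.

from transformers.testing_utils import slow


class TokenizerListConsumerSketch:
    """Illustrative only: shows the (class, pretrained_id, kwargs) shape of self.tokenizers_list."""

    @slow
    def test_pretrained_id_round_trip(self):
        # self.tokenizers_list is built by setUp from from_pretrained_id, as in the hunk above.
        for tokenizer_class, pretrained_id, kwargs in self.tokenizers_list:
            tokenizer = tokenizer_class.from_pretrained(pretrained_id, **kwargs)
            encoded = tokenizer("Hello, world!")
            # Any checkpoint pinned via from_pretrained_id should tokenize plain text.
            self.assertTrue(len(encoded["input_ids"]) > 0)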