Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00

Adds pretrained IDs directly in the tests (#29534)

* Adds pretrained IDs directly in the tests
* Fix tests
* Fix tests
* Review!
parent 38bff8c84f
commit 11bbb505c7
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "albert/albert-base-v1"
     tokenizer_class = AlbertTokenizer
     rust_tokenizer_class = AlbertTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_det

 @require_tokenizers
 class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/bart-base"
     tokenizer_class = BartTokenizer
     rust_tokenizer_class = BartTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sentencepiece
 @slow  # see https://github.com/huggingface/transformers/issues/11457
 class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "moussaKam/mbarthez"
     tokenizer_class = BarthezTokenizer
     rust_tokenizer_class = BarthezTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe.model")


 class BartphoTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "vinai/bartpho-syllable"
     tokenizer_class = BartphoTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english

 @require_tokenizers
 class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google-bert/bert-base-uncased"
     tokenizer_class = BertTokenizer
     rust_tokenizer_class = BertTokenizerFast
     test_rust_tokenizer = True

@@ -29,6 +29,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

 @require_sentencepiece
 class BertGenerationTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/bert_for_seq_generation_L-24_bbc_encoder"
     tokenizer_class = BertGenerationTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -36,6 +36,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @custom_tokenizers
 class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "cl-tohoku/bert-base-japanese"
     tokenizer_class = BertJapaneseTokenizer
     test_rust_tokenizer = False
     space_between_special_tokens = True
@@ -403,6 +404,7 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

 @custom_tokenizers
 class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "cl-tohoku/bert-base-japanese"
     tokenizer_class = BertJapaneseTokenizer
     test_rust_tokenizer = False


@@ -22,6 +22,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class BertweetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "vinai/bertweet-base"
     tokenizer_class = BertweetTokenizer
     test_rust_tokenizer = False


@@ -30,6 +30,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class BigBirdTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/bigbird-roberta-base"
     tokenizer_class = BigBirdTokenizer
     rust_tokenizer_class = BigBirdTokenizerFast
     test_rust_tokenizer = True
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_sacremoses
 class BioGptTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/biogpt"
     tokenizer_class = BioGptTokenizer
     test_rust_tokenizer = False


@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class BlenderbotSmallTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/blenderbot_small-90M"
     tokenizer_class = BlenderbotSmallTokenizer
     test_rust_tokenizer = False


@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "bigscience/tokenizer"
     slow_tokenizer_class = None
     rust_tokenizer_class = BloomTokenizerFast
     tokenizer_class = BloomTokenizerFast

@@ -32,6 +32,7 @@ FRAMEWORK = "pt" if is_torch_available() else "tf"
 @require_sentencepiece
 @require_tokenizers
 class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "almanach/camembert-base"
     tokenizer_class = CamembertTokenizer
     rust_tokenizer_class = CamembertTokenizerFast
     test_rust_tokenizer = True

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "nielsr/canine-s"
     tokenizer_class = CanineTokenizer
     test_rust_tokenizer = False


@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class CLIPTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai/clip-vit-base-patch32"
     tokenizer_class = CLIPTokenizer
     rust_tokenizer_class = CLIPTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, slow


 class ClvpTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "susnato/clvp_dev"
     tokenizer_class = ClvpTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"add_prefix_space": True}

@@ -51,6 +51,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class CodeLlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "hf-internal-testing/llama-code-tokenizer"
     tokenizer_class = CodeLlamaTokenizer
     rust_tokenizer_class = CodeLlamaTokenizerFast
     test_rust_tokenizer = False

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class CodeGenTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Salesforce/codegen-350M-mono"
     tokenizer_class = CodeGenTokenizer
     rust_tokenizer_class = CodeGenTokenizerFast
     test_rust_tokenizer = True

@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_jieba
 class CPMAntTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openbmb/cpm-ant-10b"
     tokenizer_class = CpmAntTokenizer
     test_rust_tokenizer = False

@@ -23,6 +23,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Salesforce/ctrl"
     tokenizer_class = CTRLTokenizer
     test_rust_tokenizer = False
     test_seq2seq = False

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-base"
     tokenizer_class = DebertaTokenizer
     test_rust_tokenizer = True
     rust_tokenizer_class = DebertaTokenizerFast

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-v2-xlarge"
     tokenizer_class = DebertaV2Tokenizer
     rust_tokenizer_class = DebertaV2TokenizerFast
     test_sentencepiece = True

@@ -25,6 +25,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     tokenizer_class = DistilBertTokenizer
     rust_tokenizer_class = DistilBertTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "distilbert/distilbert-base-uncased"

     @slow
     def test_sequence_builders(self):

@@ -33,6 +33,7 @@ class DPRContextEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRContextEncoderTokenizer
     rust_tokenizer_class = DPRContextEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"


 @require_tokenizers
@@ -40,6 +41,7 @@ class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRQuestionEncoderTokenizer
     rust_tokenizer_class = DPRQuestionEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"


 @require_tokenizers
@@ -47,6 +49,7 @@ class DPRReaderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRReaderTokenizer
     rust_tokenizer_class = DPRReaderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"

     @slow
     def test_decode_best_spans(self):

@@ -33,6 +33,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english

 @require_tokenizers
 class ElectraTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/electra-small-generator"
     tokenizer_class = ElectraTokenizer
     rust_tokenizer_class = ElectraTokenizerFast
     test_rust_tokenizer = True

@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class ErnieMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "susnato/ernie-m-base_pytorch"
     tokenizer_class = ErnieMTokenizer
     test_seq2seq = False
     test_sentencepiece = True

@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_g2p_en
 class FastSpeech2ConformerTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "espnet/fastspeech2_conformer"
     tokenizer_class = FastSpeech2ConformerTokenizer
     test_rust_tokenizer = False

@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class FNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/fnet-base"
     tokenizer_class = FNetTokenizer
     rust_tokenizer_class = FNetTokenizerFast
     test_rust_tokenizer = True

@@ -30,6 +30,7 @@ FSMT_TINY2 = "stas/tiny-wmt19-en-ru"


 class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "stas/tiny-wmt19-en-de"
     tokenizer_class = FSMTTokenizer
     test_rust_tokenizer = False


@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class FunnelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "funnel-transformer/small"
     tokenizer_class = FunnelTokenizer
     rust_tokenizer_class = FunnelTokenizerFast
     test_rust_tokenizer = True

@@ -49,6 +49,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/gemma-7b"
     tokenizer_class = GemmaTokenizer
     rust_tokenizer_class = GemmaTokenizerFast


@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/gpt2"
     tokenizer_class = GPT2Tokenizer
     rust_tokenizer_class = GPT2TokenizerFast
     test_rust_tokenizer = True

@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class GPTNeoXJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "abeja/gpt-neox-japanese-2.7b"
     tokenizer_class = GPTNeoXJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.mode
 @require_sentencepiece
 @require_tokenizers
 class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "AI-Sweden-Models/gpt-sw3-126m"
     tokenizer_class = GPTSw3Tokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Tanrei/GPTSAN-japanese"
     tokenizer_class = GPTSanJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sacremoses
 @require_tokenizers
 class HerbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allegro/herbert-base-cased"
     tokenizer_class = HerbertTokenizer
     rust_tokenizer_class = HerbertTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class LayoutLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlm-base-uncased"
     tokenizer_class = LayoutLMTokenizer
     rust_tokenizer_class = LayoutLMTokenizerFast
     test_rust_tokenizer = True
@@ -61,6 +61,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv2-base-uncased"
     tokenizer_class = LayoutLMv2Tokenizer
     rust_tokenizer_class = LayoutLMv2TokenizerFast
     test_rust_tokenizer = True

@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv3-base"
     tokenizer_class = LayoutLMv3Tokenizer
     rust_tokenizer_class = LayoutLMv3TokenizerFast
     test_rust_tokenizer = True

@@ -54,6 +54,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_tokenizers
 @require_pandas
 class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-roberta-base"
     tokenizer_class = LayoutXLMTokenizer
     rust_tokenizer_class = LayoutXLMTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class TestTokenizationLED(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/led-base-16384"
     tokenizer_class = LEDTokenizer
     rust_tokenizer_class = LEDTokenizerFast
     test_rust_tokenizer = True

@@ -52,6 +52,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "hf-internal-testing/llama-tokenizer"
     tokenizer_class = LlamaTokenizer
     rust_tokenizer_class = LlamaTokenizerFast


@@ -30,6 +30,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with FacebookAI/roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
 class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/longformer-base-4096"
     # Ignore copy
     tokenizer_class = LongformerTokenizer
     test_slow_tokenizer = True

@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")


 class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/luke-base"
     tokenizer_class = LukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "unc-nlp/lxmert-base-uncased"
     tokenizer_class = LxmertTokenizer
     rust_tokenizer_class = LxmertTokenizerFast
     test_rust_tokenizer = True

@@ -48,6 +48,7 @@ FR_CODE = 128028

 @require_sentencepiece
 class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/m2m100_418M"
     tokenizer_class = M2M100Tokenizer
     test_rust_tokenizer = False
     test_seq2seq = False

@@ -45,6 +45,7 @@ else:

 @require_sentencepiece
 class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Helsinki-NLP/opus-mt-en-de"
     tokenizer_class = MarianTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True
@@ -41,6 +41,7 @@ logger = logging.get_logger(__name__)

 @require_tokenizers
 class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/markuplm-base"
     tokenizer_class = MarkupLMTokenizer
     rust_tokenizer_class = MarkupLMTokenizerFast
     test_rust_tokenizer = True

@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-en-ro"
     tokenizer_class = MBartTokenizer
     rust_tokenizer_class = MBartTokenizerFast
     test_rust_tokenizer = True

@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-50-one-to-many-mmt"
     tokenizer_class = MBart50Tokenizer
     rust_tokenizer_class = MBart50TokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class MgpstrTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "alibaba-damo/mgp-str-base"
     tokenizer_class = MgpstrTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {}

@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")


 class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/mluke-base"
     tokenizer_class = MLukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}

@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english

 @require_tokenizers
 class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "mobilebert-uncased"
     tokenizer_class = MobileBertTokenizer
     rust_tokenizer_class = MobileBertTokenizerFast
     test_rust_tokenizer = True

@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/mpnet-base"
     tokenizer_class = MPNetTokenizer
     rust_tokenizer_class = MPNetTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_det

 @require_tokenizers
 class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "RUCAIBox/mvp"
     tokenizer_class = MvpTokenizer
     rust_tokenizer_class = MvpTokenizerFast
     test_rust_tokenizer = True

@@ -49,6 +49,7 @@ RO_CODE = 256145
 @require_sentencepiece
 @require_tokenizers
 class NllbTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nllb-200-distilled-600M"
     tokenizer_class = NllbTokenizer
     rust_tokenizer_class = NllbTokenizerFast
     test_rust_tokenizer = True

@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class NougatTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nougat-base"
     slow_tokenizer_class = None
     rust_tokenizer_class = NougatTokenizerFast
     tokenizer_class = NougatTokenizerFast
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/openai-gpt"
     """Tests OpenAIGPTTokenizer that uses BERT BasicTokenizer."""

     tokenizer_class = OpenAIGPTTokenizer

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
 @require_sentencepiece
 @require_tokenizers
 class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True
@@ -135,6 +136,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True

@@ -36,6 +36,7 @@ else:


 class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "deepmind/language-perceiver"
     tokenizer_class = PerceiverTokenizer
     test_rust_tokenizer = False


@@ -22,6 +22,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class PhobertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "vinai/phobert-base"
     tokenizer_class = PhobertTokenizer
     test_rust_tokenizer = False


@@ -40,6 +40,7 @@ PYTHON_CODE = 50002
 @require_sentencepiece
 @require_tokenizers
 class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "uclanlp/plbart-base"
     tokenizer_class = PLBartTokenizer
     rust_tokenizer_class = None
     test_rust_tokenizer = False

@@ -32,6 +32,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/prophetnet-large-uncased"
     tokenizer_class = ProphetNetTokenizer
     test_rust_tokenizer = False


@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class Qwen2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "qwen/qwen-tokenizer"
     tokenizer_class = Qwen2Tokenizer
     rust_tokenizer_class = Qwen2TokenizerFast
     test_slow_tokenizer = True

@@ -33,6 +33,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english

 @require_tokenizers
 class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/realm-cc-news-pretrained-embedder"
     tokenizer_class = RealmTokenizer
     rust_tokenizer_class = RealmTokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/reformer-crime-and-punishment"
     tokenizer_class = ReformerTokenizer
     rust_tokenizer_class = ReformerTokenizerFast
     test_rust_tokenizer = True
@@ -32,6 +32,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class RemBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/rembert"
     tokenizer_class = RemBertTokenizer
     rust_tokenizer_class = RemBertTokenizerFast
     space_between_special_tokens = True

@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_tokenizers
 class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/roberta-base"
     tokenizer_class = RobertaTokenizer
     rust_tokenizer_class = RobertaTokenizerFast
     test_rust_tokenizer = True

@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english

 @require_tokenizers
 class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "weiweishi/roc-bert-base-zh"
     tokenizer_class = RoCBertTokenizer
     rust_tokenizer_class = None
     test_rust_tokenizer = False

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_rjieba
 @require_tokenizers
 class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "junnyu/roformer_chinese_small"
     tokenizer_class = RoFormerTokenizer
     rust_tokenizer_class = RoFormerTokenizerFast
     space_between_special_tokens = True

@@ -53,6 +53,7 @@ SMALL_TRAINING_CORPUS = [
 @require_sentencepiece
 @require_tokenizers
 class SeamlessM4TTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/hf-seamless-m4t-medium"
     tokenizer_class = SeamlessM4TTokenizer
     rust_tokenizer_class = SeamlessM4TTokenizerFast
     test_rust_tokenizer = True

@@ -38,6 +38,7 @@ else:
 @require_sentencepiece
 @require_tokenizers
 class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/siglip-base-patch16-224"
     tokenizer_class = SiglipTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -37,6 +37,7 @@ ES_CODE = 10
 @require_sentencepiece
 @require_tokenizers
 class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/s2t-small-librispeech-asr"
     tokenizer_class = Speech2TextTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/s2t-wav2vec2-large-en-de"
     tokenizer_class = Speech2Text2Tokenizer
     test_rust_tokenizer = False


@@ -30,6 +30,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
 @require_sentencepiece
 @require_tokenizers
 class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/speecht5_asr"
     tokenizer_class = SpeechT5Tokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True

@@ -25,6 +25,7 @@ class SqueezeBertTokenizationTest(BertTokenizationTest):
     tokenizer_class = SqueezeBertTokenizer
     rust_tokenizer_class = SqueezeBertTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "squeezebert/squeezebert-uncased"

     def get_rust_tokenizer(self, **kwargs):
         return SqueezeBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
@@ -38,6 +38,7 @@ else:
 @require_sentencepiece
 @require_tokenizers
 class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google-t5/t5-small"
     tokenizer_class = T5Tokenizer
     rust_tokenizer_class = T5TokenizerFast
     test_rust_tokenizer = True

@@ -53,6 +53,7 @@ else:
 @require_tokenizers
 @require_pandas
 class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/tapas-large-finetuned-sqa"
     tokenizer_class = TapasTokenizer
     test_rust_tokenizer = False
     space_between_special_tokens = True

@@ -54,6 +54,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_tokenizers
 @require_pandas
 class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/udop-large"
     tokenizer_class = UdopTokenizer
     rust_tokenizer_class = UdopTokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class VitsTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mms-tts-eng"
     tokenizer_class = VitsTokenizer
     test_rust_tokenizer = False


@@ -367,6 +367,7 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):


 class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/wav2vec2-base-960h"
     tokenizer_class = Wav2Vec2CTCTokenizer
     test_rust_tokenizer = False


@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin

 @require_phonemizer
 class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/wav2vec2-lv-60-espeak-cv-ft"
     tokenizer_class = Wav2Vec2PhonemeCTCTokenizer
     test_rust_tokenizer = False


@@ -31,6 +31,7 @@ NOTIMESTAMPS = 50363


 class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai/whisper-tiny"
     tokenizer_class = WhisperTokenizer
     rust_tokenizer_class = WhisperTokenizerFast
     test_rust_tokenizer = True

@@ -31,6 +31,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class XGLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/xglm-564M"
     tokenizer_class = XGLMTokenizer
     rust_tokenizer_class = XGLMTokenizerFast
     test_rust_tokenizer = True

@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin


 class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-mlm-en-2048"
     tokenizer_class = XLMTokenizer
     test_rust_tokenizer = False


@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

 @require_sentencepiece
 class XLMProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/xprophetnet-large-wiki100-cased"
     tokenizer_class = XLMProphetNetTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True
@@ -31,6 +31,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-roberta-base"
     tokenizer_class = XLMRobertaTokenizer
     rust_tokenizer_class = XLMRobertaTokenizerFast
     test_rust_tokenizer = True

@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_sentencepiece
 @require_tokenizers
 class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "xlnet/xlnet-base-cased"
     tokenizer_class = XLNetTokenizer
     rust_tokenizer_class = XLNetTokenizerFast
     test_rust_tokenizer = True

@@ -186,6 +186,7 @@ class TokenizerTesterMixin:
     space_between_special_tokens = False
     from_pretrained_kwargs = None
     from_pretrained_filter = None
+    from_pretrained_id = None
     from_pretrained_vocab_key = "vocab_file"
     test_seq2seq = True

@@ -200,19 +201,13 @@ class TokenizerTesterMixin:
         # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
         # information available in Tokenizer (name, rust class, python class, vocab key name)
         if self.test_rust_tokenizer:
-            tokenizers_list = [
+            self.tokenizers_list = [
                 (
                     self.rust_tokenizer_class,
-                    pretrained_name,
+                    self.from_pretrained_id,
                     self.from_pretrained_kwargs if self.from_pretrained_kwargs is not None else {},
                 )
-                for pretrained_name in self.rust_tokenizer_class.pretrained_vocab_files_map[
-                    self.from_pretrained_vocab_key
-                ].keys()
-                if self.from_pretrained_filter is None
-                or (self.from_pretrained_filter is not None and self.from_pretrained_filter(pretrained_name))
             ]
-            self.tokenizers_list = tokenizers_list[:1]  # Let's just test the first pretrained vocab for speed
         else:
             self.tokenizers_list = []
         with open(f"{get_tests_dir()}/fixtures/sample_text.txt", encoding="utf-8") as f_data:
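For context, a minimal standalone sketch of what the refactored setUp above now does for a fast-tokenizer test: it pairs the class's Rust tokenizer with the single checkpoint named by from_pretrained_id instead of iterating over pretrained_vocab_files_map. The Albert values from the diff are used as the example; this assumes transformers (with the fast tokenizers backend) is installed and the Hugging Face Hub checkpoint is reachable.

# Sketch only: mirrors the tuple that TokenizerTesterMixin.setUp now builds,
# using the Albert test's values from the diff above.
from transformers import AlbertTokenizerFast

from_pretrained_id = "albert/albert-base-v1"
from_pretrained_kwargs = None  # the Albert test class does not set any

# Equivalent of: self.tokenizers_list = [(rust_class, from_pretrained_id, kwargs)]
tokenizers_list = [
    (
        AlbertTokenizerFast,
        from_pretrained_id,
        from_pretrained_kwargs if from_pretrained_kwargs is not None else {},
    )
]

# The shared tests iterate over this list and load each checkpoint by ID.
for tokenizer_class, pretrained_id, kwargs in tokenizers_list:
    tokenizer = tokenizer_class.from_pretrained(pretrained_id, **kwargs)
    print(tokenizer_class.__name__, pretrained_id, tokenizer("Hello world")["input_ids"])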