diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 6b9aa256fa3..b4c21a6b0cc 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -118,7 +118,7 @@ from .pipelines import (
 # Tokenizers
 from .tokenization_albert import AlbertTokenizer
 from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
-from .tokenization_bart import BartTokenizer, MBartTokenizer
+from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
 from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
 from .tokenization_camembert import CamembertTokenizer
diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index e272281d898..e2297aa8d68 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -16,7 +16,7 @@
 import logging
 from typing import List, Optional

-from .tokenization_roberta import RobertaTokenizer
+from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 from .tokenization_utils import BatchEncoding
 from .tokenization_xlm_roberta import XLMRobertaTokenizer

@@ -44,6 +44,15 @@ class BartTokenizer(RobertaTokenizer):
     }


+class BartTokenizerFast(RobertaTokenizerFast):
+    # merges and vocab same as Roberta
+    max_model_input_sizes = {m: 1024 for m in _all_bart_models}
+    pretrained_vocab_files_map = {
+        "vocab_file": {m: vocab_url for m in _all_bart_models},
+        "merges_file": {m: merges_url for m in _all_bart_models},
+    }
+
+
 _all_mbart_models = ["facebook/mbart-large-en-ro"]
 SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
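
A minimal usage sketch (not part of the diff above): since BartTokenizerFast reuses Roberta's vocab and merges files, it loads from the same checkpoint names as the slow BartTokenizer. The "facebook/bart-large" checkpoint and the example text below are illustrative assumptions.

    from transformers import BartTokenizer, BartTokenizerFast

    # Load the slow and fast tokenizers from the same pretrained checkpoint.
    slow = BartTokenizer.from_pretrained("facebook/bart-large")
    fast = BartTokenizerFast.from_pretrained("facebook/bart-large")

    # Both should produce the same token ids for plain text input.
    print(slow.encode("Hello world"))
    print(fast.encode("Hello world"))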