diff --git a/.circleci/config.yml b/.circleci/config.yml index 01e6d82b334..9a81eea902e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -70,6 +70,27 @@ jobs: - run: sudo pip install pytest codecov pytest-cov - run: python -m pytest -sv ./transformers/tests/ --cov - run: codecov + build_py3_custom_tokenizers: + working_directory: ~/transformers + docker: + - image: circleci/python:3.5 + steps: + - checkout + - run: sudo pip install --progress-bar off . + - run: sudo pip install pytest + - run: sudo pip install mecab-python3 + - run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py + build_py2_custom_tokenizers: + working_directory: ~/transformers + docker: + - image: circleci/python:2.7 + steps: + - checkout + - run: sudo pip install --progress-bar off . + - run: sudo pip install pytest + - run: sudo apt-get -y install libmecab-dev mecab mecab-ipadic-utf8 swig + - run: sudo pip install mecab-python + - run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py deploy_doc: working_directory: ~/transformers docker: @@ -82,6 +103,16 @@ jobs: - run: sudo pip install --progress-bar off -r docs/requirements.txt - run: sudo pip install --progress-bar off -r requirements.txt - run: ./.circleci/deploy.sh + repository_consistency: + working_directory: ~/transformers + docker: + - image: circleci/python:3.5 + resource_class: small + parallelism: 1 + steps: + - checkout + - run: sudo pip install requests + - run: python ./utils/link_tester.py workflow_filters: &workflow_filters filters: branches: @@ -91,6 +122,9 @@ workflows: version: 2 build_and_test: jobs: + - repository_consistency + - build_py3_custom_tokenizers + - build_py2_custom_tokenizers - build_py3_torch_and_tf - build_py3_torch - build_py3_tf diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 8ca7ba0d406..7e1366b53a4 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -61,6 +61,24 @@ Here is the full list of the currently provided pretrained models together with | | ``bert-base-german-dbmdz-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | | | | | Trained on uncased German text by DBMDZ | | | | (see `details on dbmdz repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text. Text is tokenized with MeCab and WordPiece. | +| | | | `MeCab `__ is required for tokenization. | +| | | (see `details on cl-tohoku repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized with MeCab and WordPiece. | +| | | | `MeCab `__ is required for tokenization. | +| | | (see `details on cl-tohoku repository `__). 
| +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text. Text is tokenized into characters. | +| | | (see `details on cl-tohoku repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. | +| | | (see `details on cl-tohoku repository `__). | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | GPT | ``openai-gpt`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | | | | | OpenAI GPT English model | @@ -169,35 +187,35 @@ Here is the full list of the currently provided pretrained models together with +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | ALBERT | ``albert-base-v1`` | | 12 repeating layers, 128 embedding, 768-hidden, 12-heads, 11M parameters | | | | | ALBERT base model | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``albert-large-v1`` | | 24 repeating layers, 128 embedding, 1024-hidden, 16-heads, 17M parameters | | | | | ALBERT large model | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``albert-xlarge-v1`` | | 24 repeating layers, 128 embedding, 2048-hidden, 16-heads, 58M parameters | | | | | ALBERT xlarge model | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``albert-xxlarge-v1`` | | 12 repeating layer, 128 embedding, 4096-hidden, 64-heads, 223M parameters | | | | | ALBERT xxlarge model | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``albert-base-v2`` | | 12 repeating layers, 128 embedding, 768-hidden, 12-heads, 11M parameters | | | | | ALBERT base model with no dropout, additional training data and longer training | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | 
``albert-large-v2`` | | 24 repeating layers, 128 embedding, 1024-hidden, 16-heads, 17M parameters | | | | | ALBERT large model with no dropout, additional training data and longer training | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``albert-xlarge-v2`` | | 24 repeating layers, 128 embedding, 2048-hidden, 16-heads, 58M parameters | | | | | ALBERT xlarge model with no dropout, additional training data and longer training | -| | | (see `details `__) | +| | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``albert-xxlarge-v2`` | | 12 repeating layer, 128 embedding, 4096-hidden, 64-heads, 223M parameters | | | | | ALBERT xxlarge model with no dropout, additional training data and longer training | -| | | (see `details `__) | +| | | (see `details `__) | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | T5 | ``t5-small`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | | | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint | diff --git a/examples/run_glue.py b/examples/run_glue.py index 369a7110ab5..1a51255c110 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -380,7 +380,7 @@ def main(): parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.0, type=float, - help="Weight deay if we apply some.") + help="Weight decay if we apply some.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--max_grad_norm", default=1.0, type=float, diff --git a/examples/run_squad.py b/examples/run_squad.py index 2df29014ef6..117b86e32cd 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -299,10 +299,13 @@ def evaluate(args, model, tokenizer, prefix=""): # XLNet and XLM use a more complex post-processing procedure if args.model_type in ['xlnet', 'xlm']: + start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top + end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top + predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size, args.max_answer_length, output_prediction_file, output_nbest_file, output_null_log_odds_file, - model.config.start_n_top, model.config.end_n_top, + start_n_top, end_n_top, args.version_2_with_negative, tokenizer, args.verbose_logging) else: predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size, @@ -334,7 +337,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal else: logger.info("Creating features from dataset file at %s", input_dir) - if not args.data_dir: + if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)): try: import tensorflow_datasets as tfds except ImportError: @@ -347,7 +350,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, 
output_examples=Fal examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate) else: processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() - examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) + + if evaluate: + examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file) + else: + examples = processor.get_train_examples(args.data_dir, filename=args.train_file) features, dataset = squad_convert_examples_to_features( examples=examples, @@ -384,7 +391,14 @@ def main(): ## Other parameters parser.add_argument("--data_dir", default=None, type=str, - help="The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets.") + help="The input data dir. Should contain the .json files for the task." + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.") + parser.add_argument("--train_file", default=None, type=str, + help="The input training file. If a data dir is specified, will look for the file there" + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.") + parser.add_argument("--predict_file", default=None, type=str, + help="The input evaluation file. If a data dir is specified, will look for the file there" + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.") parser.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name") parser.add_argument("--tokenizer_name", default="", type=str, @@ -469,11 +483,6 @@ def main(): parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") args = parser.parse_args() - args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format( - list(filter(None, args.model_name_or_path.split('/'))).pop(), - str(args.max_seq_length)) - ) - if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) diff --git a/transformers/__init__.py b/transformers/__init__.py index 3133101eb78..f456c859aa4 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -37,6 +37,7 @@ if is_sklearn_available(): from .tokenization_utils import (PreTrainedTokenizer) from .tokenization_auto import AutoTokenizer from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer +from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_gpt2 import GPT2Tokenizer diff --git a/transformers/configuration_auto.py b/transformers/configuration_auto.py index 477df8981ce..680c55fa54c 100644 --- a/transformers/configuration_auto.py +++ b/transformers/configuration_auto.py @@ -85,6 +85,7 @@ class AutoConfig(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. 
- a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. diff --git a/transformers/configuration_bert.py b/transformers/configuration_bert.py index d63be963eba..01fcd88cb81 100644 --- a/transformers/configuration_bert.py +++ b/transformers/configuration_bert.py @@ -42,6 +42,10 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", + 'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json", + 'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json", + 'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json", + 'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json" } diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 08cee75d81b..82959adb576 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -24,7 +24,7 @@ import logging import os from io import open -from .file_utils import cached_path, CONFIG_NAME +from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url logger = logging.getLogger(__name__) @@ -79,6 +79,7 @@ class PretrainedConfig(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. 
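The docstring addition above is backed by the fallback branch in the hunk that follows: a name that is not a known shortcut, a local directory, a local file, or a remote URL is now treated as a user-uploaded identifier and resolved against the S3 bucket. A minimal, self-contained sketch of that lookup order, assuming Python 3 (`is_remote_url` and `hf_bucket_url` mirror the helpers added to file_utils.py further down in this diff, reproduced here only for illustration):

import os
from urllib.parse import urlparse

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CONFIG_NAME = "config.json"

def is_remote_url(url_or_filename):
    # Anything with an http/https/s3 scheme is fetched rather than opened locally.
    return urlparse(url_or_filename).scheme in ('http', 'https', 's3')

def hf_bucket_url(identifier, postfix=None):
    # Map an identifier such as 'dbmdz/bert-base-german-cased' to its bucket URL.
    if postfix is None:
        return "/".join((S3_BUCKET_PREFIX, identifier))
    return "/".join((S3_BUCKET_PREFIX, identifier, postfix))

def resolve_config_file(name_or_path, archive_map):
    if name_or_path in archive_map:
        return archive_map[name_or_path]                     # hard-coded shortcut name
    if os.path.isdir(name_or_path):
        return os.path.join(name_or_path, CONFIG_NAME)       # saved with save_pretrained()
    if os.path.isfile(name_or_path) or is_remote_url(name_or_path):
        return name_or_path                                  # explicit file path or URL
    return hf_bucket_url(name_or_path, postfix=CONFIG_NAME)  # user-uploaded identifier

Whatever the branch returns is then passed through `cached_path`, so identifier downloads share the same cache as shortcut names.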
@@ -131,8 +132,10 @@ class PretrainedConfig(object): config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] elif os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) - else: + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path + else: + config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME) # redirect to the cache, if necessary try: resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, @@ -187,7 +190,7 @@ class PretrainedConfig(object): @classmethod def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" + """Constructs a `Config` from a json file of parameters.""" with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) diff --git a/transformers/data/metrics/squad_metrics.py b/transformers/data/metrics/squad_metrics.py index 0755c0ab7a8..7b03255f496 100644 --- a/transformers/data/metrics/squad_metrics.py +++ b/transformers/data/metrics/squad_metrics.py @@ -695,7 +695,12 @@ def compute_predictions_log_probs( tok_text = " ".join(tok_text.split()) orig_text = " ".join(orig_tokens) - final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, + if hasattr(tokenizer, "do_lower_case"): + do_lower_case = tokenizer.do_lower_case + else: + do_lower_case = tokenizer.do_lowercase_and_remove_accent + + final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) if final_text in seen_predictions: diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index 3d7f8325406..9bc43756842 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -373,6 +373,9 @@ class SquadProcessor(DataProcessor): which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. """ + if data_dir is None: + data_dir = "" + if self.train_file is None: raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") @@ -389,6 +392,9 @@ class SquadProcessor(DataProcessor): filename: None by default, specify this if the evaluation file has a different name than the original one which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 
""" + if data_dir is None: + data_dir = "" + if self.dev_file is None: raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") diff --git a/transformers/file_utils.py b/transformers/file_utils.py index e36bbf4eeb6..03b2fdb9f40 100644 --- a/transformers/file_utils.py +++ b/transformers/file_utils.py @@ -21,7 +21,7 @@ import boto3 from botocore.config import Config from botocore.exceptions import ClientError import requests -from tqdm import tqdm +from tqdm.auto import tqdm from contextlib import contextmanager logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -73,9 +73,13 @@ TF2_WEIGHTS_NAME = 'tf_model.h5' TF_WEIGHTS_NAME = 'model.ckpt' CONFIG_NAME = "config.json" + DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] +S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" + + def is_torch_available(): return _torch_available @@ -106,6 +110,18 @@ else: return fn return docstring_decorator + +def is_remote_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ('http', 'https', 's3') + +def hf_bucket_url(identifier, postfix=None): + if postfix is None: + return "/".join((S3_BUCKET_PREFIX, identifier)) + else: + return "/".join((S3_BUCKET_PREFIX, identifier, postfix)) + + def url_to_filename(url, etag=None): """ Convert `url` into a hashed filename in a repeatable way. @@ -174,9 +190,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N if sys.version_info[0] == 3 and isinstance(cache_dir, Path): cache_dir = str(cache_dir) - parsed = urlparse(url_or_filename) - - if parsed.scheme in ('http', 'https', 's3'): + if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -184,7 +198,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename - elif parsed.scheme == '': + elif urlparse(url_or_filename).scheme == '': # File, but it doesn't exist. 
raise EnvironmentError("file {} not found".format(url_or_filename)) else: @@ -248,7 +262,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0): return content_length = response.headers.get('Content-Length') total = resume_size + int(content_length) if content_length is not None else None - progress = tqdm(unit="B", total=total, initial=resume_size) + progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size, desc="Downloading") for chunk in response.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) diff --git a/transformers/modeling_auto.py b/transformers/modeling_auto.py index b0b42533ac7..19a54cca864 100644 --- a/transformers/modeling_auto.py +++ b/transformers/modeling_auto.py @@ -28,7 +28,6 @@ from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassifica from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice -from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering from .modeling_t5 import T5Model, T5WithLMHeadModel @@ -97,6 +96,7 @@ class AutoModel(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. @@ -239,6 +239,7 @@ class AutoModelWithLMHead(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. @@ -370,6 +371,7 @@ class AutoModelForSequenceClassification(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. 
+ - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. @@ -488,6 +490,7 @@ class AutoModelForQuestionAnswering(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index d84b0a1a7cc..d0f35272ac0 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -48,6 +48,10 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin", 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin", + 'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin", + 'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin", + 'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin", + 'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin" } @@ -1233,9 +1237,9 @@ class BertForQuestionAnswering(BertPreTrainedModel): question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]" input_ids = tokenizer.encode(input_text) - token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))] + token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))] start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids])) - all_tokens = tokenizer.convert_ids_to_tokens(input_ids) + all_tokens = tokenizer.convert_ids_to_tokens(input_ids) print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])) # 
a nice puppet diff --git a/transformers/modeling_encoder_decoder.py b/transformers/modeling_encoder_decoder.py index 713cf5252e6..a91c046d8f4 100644 --- a/transformers/modeling_encoder_decoder.py +++ b/transformers/modeling_encoder_decoder.py @@ -59,12 +59,14 @@ class PreTrainedEncoderDecoder(nn.Module): encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. diff --git a/transformers/modeling_tf_auto.py b/transformers/modeling_tf_auto.py index fc8f8d8ab12..b4ff660098e 100644 --- a/transformers/modeling_tf_auto.py +++ b/transformers/modeling_tf_auto.py @@ -84,6 +84,7 @@ class TFAutoModel(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. @@ -219,6 +220,7 @@ class TFAutoModelWithLMHead(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. 
- a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. @@ -347,6 +349,7 @@ class TFAutoModelForSequenceClassification(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. @@ -462,6 +465,7 @@ class TFAutoModelForQuestionAnswering(object): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. diff --git a/transformers/modeling_tf_bert.py b/transformers/modeling_tf_bert.py index 5aa7bb3da26..7cc71f50633 100644 --- a/transformers/modeling_tf_bert.py +++ b/transformers/modeling_tf_bert.py @@ -48,6 +48,10 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5", 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5", 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5", + 'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5", + 'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5", + 'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5", + 'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5" } @@ -129,7 +133,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer): linear tensor, float32 with shape [batch_size, length, vocab_size]. Raises: ValueError: if mode is not valid. 
- + Shared weights logic adapted from https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 """ @@ -148,7 +152,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer): input_shape = shape_list(input_ids) else: input_shape = shape_list(inputs_embeds)[:-1] - + seq_length = input_shape[1] if position_ids is None: position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :] @@ -246,7 +250,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer): context_layer = tf.matmul(attention_probs, value_layer) context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - context_layer = tf.reshape(context_layer, + context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size)) # (batch_size, seq_len_q, all_head_size) outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) @@ -591,7 +595,7 @@ BERT_START_DOCSTRING = r""" The BERT model was proposed in `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})` Parameters: - config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ @@ -605,13 +609,13 @@ BERT_INPUTS_DOCSTRING = r""" (a) For sequence pairs: ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` - + ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` (b) For single sequences: ``tokens: [CLS] the dog is hairy . [SEP]`` - + ``token_type_ids: 0 0 0 0 0 0 0`` Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index 8d010e589e5..6fb4850b05c 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -24,7 +24,8 @@ import os import tensorflow as tf from .configuration_utils import PretrainedConfig -from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS +from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, DUMMY_INPUTS, + cached_path, hf_bucket_url, is_remote_url) from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) @@ -174,6 +175,7 @@ class TFPreTrainedModel(tf.keras.Model): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards. 
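Taken together, the docstring updates above mean the same `from_pretrained` entry point now covers shortcut names, local paths, remote URLs, and user-uploaded identifiers. A short usage sketch (the dbmdz identifier is the example used throughout these docstrings; each call downloads on first use and then hits the local cache):

from transformers import AutoConfig, AutoModel, AutoTokenizer

# Shortcut name: resolved through the hard-coded archive maps, as before.
config = AutoConfig.from_pretrained('bert-base-uncased')

# User-uploaded identifier: no archive-map entry, so config and weights are
# fetched from .../models.huggingface.co/bert/dbmdz/bert-base-german-cased/...
model = AutoModel.from_pretrained('dbmdz/bert-base-german-cased')
tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')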
@@ -255,10 +257,14 @@ class TFPreTrainedModel(tf.keras.Model): raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format( [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path)) - elif os.path.isfile(pretrained_model_name_or_path): + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + archive_file = pretrained_model_name_or_path + ".index" else: - raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path)) + archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME) + if from_pt: + raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.") # redirect to the cache, if necessary try: diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index ae515d68706..9bd99b25dcb 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss from torch.nn import functional as F from .configuration_utils import PretrainedConfig -from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS +from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, DUMMY_INPUTS, + cached_path, hf_bucket_url, is_remote_url) logger = logging.getLogger(__name__) @@ -272,6 +273,7 @@ class PreTrainedModel(nn.Module): pretrained_model_name_or_path: either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. - None if you are both providing the configuration and state dictionary (resp. 
with keyword arguments ``config`` and ``state_dict``) @@ -370,11 +372,16 @@ class PreTrainedModel(nn.Module): raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format( [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], pretrained_model_name_or_path)) - elif os.path.isfile(pretrained_model_name_or_path): + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path - else: - assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path) + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( + pretrained_model_name_or_path + ".index") archive_file = pretrained_model_name_or_path + ".index" + else: + archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME) + if from_tf: + raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.") # redirect to the cache, if necessary try: diff --git a/transformers/tests/modeling_auto_test.py b/transformers/tests/modeling_auto_test.py index 9b7d920bc86..871a262fe8c 100644 --- a/transformers/tests/modeling_auto_test.py +++ b/transformers/tests/modeling_auto_test.py @@ -22,7 +22,7 @@ import logging from transformers import is_torch_available -from .utils import require_torch, slow +from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER if is_torch_available(): from transformers import (AutoConfig, BertConfig, @@ -92,6 +92,11 @@ class AutoModelTest(unittest.TestCase): self.assertIsNotNone(model) self.assertIsInstance(model, BertForQuestionAnswering) + def test_from_pretrained_identifier(self): + logging.basicConfig(level=logging.INFO) + model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) + self.assertIsInstance(model, BertForMaskedLM) + if __name__ == "__main__": unittest.main() diff --git a/transformers/tests/modeling_tf_auto_test.py b/transformers/tests/modeling_tf_auto_test.py index 7ea48015d9b..7ab6eaa3d63 100644 --- a/transformers/tests/modeling_tf_auto_test.py +++ b/transformers/tests/modeling_tf_auto_test.py @@ -22,7 +22,7 @@ import logging from transformers import is_tf_available -from .utils import require_tf, slow +from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER if is_tf_available(): from transformers import (AutoConfig, BertConfig, @@ -93,6 +93,11 @@ class TFAutoModelTest(unittest.TestCase): self.assertIsNotNone(model) self.assertIsInstance(model, TFBertForQuestionAnswering) + def test_from_pretrained_identifier(self): + logging.basicConfig(level=logging.INFO) + model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True) + self.assertIsInstance(model, TFBertForMaskedLM) + if __name__ == "__main__": unittest.main() diff --git a/transformers/tests/tokenization_auto_test.py b/transformers/tests/tokenization_auto_test.py index 18346d27688..0a894cac043 100644 --- a/transformers/tests/tokenization_auto_test.py +++ b/transformers/tests/tokenization_auto_test.py @@ -23,7 +23,7 @@ import logging from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP -from .utils import slow +from .utils import slow, SMALL_MODEL_IDENTIFIER class 
AutoTokenizerTest(unittest.TestCase): @@ -42,6 +42,11 @@ class AutoTokenizerTest(unittest.TestCase): self.assertIsInstance(tokenizer, GPT2Tokenizer) self.assertGreater(len(tokenizer), 0) + def test_tokenizer_from_pretrained_identifier(self): + logging.basicConfig(level=logging.INFO) + tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) + self.assertIsInstance(tokenizer, BertTokenizer) + self.assertEqual(len(tokenizer), 12) if __name__ == "__main__": unittest.main() diff --git a/transformers/tests/tokenization_bert_japanese_test.py b/transformers/tests/tokenization_bert_japanese_test.py new file mode 100644 index 00000000000..545193c7cce --- /dev/null +++ b/transformers/tests/tokenization_bert_japanese_test.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import unittest +from io import open + +from transformers.tokenization_bert import WordpieceTokenizer +from transformers.tokenization_bert_japanese import (BertJapaneseTokenizer, + MecabTokenizer, CharacterTokenizer, + VOCAB_FILES_NAMES) + +from .tokenization_tests_commons import CommonTestCases +from .utils import slow, custom_tokenizers + + +@custom_tokenizers +class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester): + + tokenizer_class = BertJapaneseTokenizer + + def setUp(self): + super(BertJapaneseTokenizationTest, self).setUp() + + vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", + u"こんにちは", u"こん", u"にちは", u"ばんは", u"##こん", u"##にちは", u"##ばんは", + u"世界", u"##世界", u"、", u"##、", u"。", u"##。"] + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def get_tokenizer(self, **kwargs): + return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_input_output_texts(self): + input_text = u"こんにちは、世界。 \nこんばんは、世界。" + output_text = u"こんにちは 、 世界 。 こんばんは 、 世界 。" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file) + + tokens = tokenizer.tokenize(u"こんにちは、世界。\nこんばんは、世界。") + self.assertListEqual(tokens, + [u"こんにちは", u"、", u"世界", u"。", + u"こん", u"##ばんは", u"、", u"世界", "。"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), + [3, 12, 10, 14, 4, 9, 12, 10, 14]) + + def test_mecab_tokenizer(self): + tokenizer = MecabTokenizer() + + self.assertListEqual( + tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), + [u"アップルストア", u"で", u"iPhone", u"8", u"が", + u"発売", u"さ", u"れ", u"た", u"。"]) + + def test_mecab_tokenizer_lower(self): + tokenizer = MecabTokenizer(do_lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), + [u"アップルストア", u"で", u"iphone", u"8", u"が", + u"発売", u"さ", u"れ", u"た", u"。"]) + + def 
test_mecab_tokenizer_no_normalize(self): + tokenizer = MecabTokenizer(normalize_text=False) + + self.assertListEqual( + tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), + [u"アップルストア", u"で", u"iPhone", u"8", u"が", + u"発売", u"さ", u"れ", u"た", u" ", u"。"]) + + def test_wordpiece_tokenizer(self): + vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", + u"こんにちは", u"こん", u"にちは" u"ばんは", u"##こん", u"##にちは", u"##ばんは"] + + vocab = {} + for (i, token) in enumerate(vocab_tokens): + vocab[token] = i + tokenizer = WordpieceTokenizer(vocab=vocab, unk_token=u"[UNK]") + + self.assertListEqual(tokenizer.tokenize(u""), []) + + self.assertListEqual(tokenizer.tokenize(u"こんにちは"), + [u"こんにちは"]) + + self.assertListEqual(tokenizer.tokenize(u"こんばんは"), + [u"こん", u"##ばんは"]) + + self.assertListEqual(tokenizer.tokenize(u"こんばんは こんばんにちは こんにちは"), + [u"こん", u"##ばんは", u"[UNK]", u"こんにちは"]) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese") + + text = tokenizer.encode(u"ありがとう。", add_special_tokens=False) + text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False) + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + # 2 is for "[CLS]", 3 is for "[SEP]" + assert encoded_sentence == [2] + text + [3] + assert encoded_pair == [2] + text + [3] + text_2 + [3] + + +class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTester): + + tokenizer_class = BertJapaneseTokenizer + + def setUp(self): + super(BertJapaneseCharacterTokenizationTest, self).setUp() + + vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", + u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界", u"、", u"。"] + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def get_tokenizer(self, **kwargs): + return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, + subword_tokenizer_type="character", + **kwargs) + + def get_input_output_texts(self): + input_text = u"こんにちは、世界。 \nこんばんは、世界。" + output_text = u"こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file, + subword_tokenizer_type="character") + + tokens = tokenizer.tokenize(u"こんにちは、世界。 \nこんばんは、世界。") + self.assertListEqual(tokens, + [u"こ", u"ん", u"に", u"ち", u"は", u"、", u"世", u"界", u"。", + u"こ", u"ん", u"ば", u"ん", u"は", u"、", u"世", u"界", u"。"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), + [3, 4, 5, 6, 7, 11, 9, 10, 12, + 3, 4, 8, 4, 7, 11, 9, 10, 12]) + + def test_character_tokenizer(self): + vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", + u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界"u"、", u"。"] + + vocab = {} + for (i, token) in enumerate(vocab_tokens): + vocab[token] = i + tokenizer = CharacterTokenizer(vocab=vocab, unk_token=u"[UNK]") + + self.assertListEqual(tokenizer.tokenize(u""), []) + + self.assertListEqual(tokenizer.tokenize(u"こんにちは"), + [u"こ", u"ん", u"に", u"ち", u"は"]) + + self.assertListEqual(tokenizer.tokenize(u"こんにちほ"), + [u"こ", u"ん", u"に", u"ち", u"[UNK]"]) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese-char") + + text = tokenizer.encode(u"ありがとう。", add_special_tokens=False) + text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False) + + encoded_sentence = 
tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + # 2 is for "[CLS]", 3 is for "[SEP]" + assert encoded_sentence == [2] + text + [3] + assert encoded_pair == [2] + text + [3] + text_2 + [3] + + + diff --git a/transformers/tests/utils.py b/transformers/tests/utils.py index 7a51ab612b6..c950ad8f17e 100644 --- a/transformers/tests/utils.py +++ b/transformers/tests/utils.py @@ -6,18 +6,26 @@ from distutils.util import strtobool from transformers.file_utils import _tf_available, _torch_available -try: - run_slow = os.environ["RUN_SLOW"] -except KeyError: - # RUN_SLOW isn't set, default to skipping slow tests. - _run_slow_tests = False -else: - # RUN_SLOW is set, convert it to True or False. +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" + + +def parse_flag_from_env(key, default=False): try: - _run_slow_tests = strtobool(run_slow) - except ValueError: - # More values are supported, but let's keep the message simple. - raise ValueError("If set, RUN_SLOW must be yes or no.") + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError("If set, {} must be yes or no.".format(key)) + return _value + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) def slow(test_case): @@ -33,6 +41,19 @@ def slow(test_case): return test_case +def custom_tokenizers(test_case): + """ + Decorator marking a test for a custom tokenizer. + + Custom tokenizers require additional dependencies, and are skipped + by default. Set the RUN_CUSTOM_TOKENIZERS environment variable + to a truthy value to run them. + """ + if not _run_custom_tokenizers: + test_case = unittest.skip("test of custom tokenizers")(test_case) + return test_case + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch. @@ -59,6 +80,6 @@ def require_tf(test_case): if _torch_available: # Set the USE_CUDA environment variable to select a GPU. 
-    torch_device = "cuda" if os.environ.get("USE_CUDA") else "cpu"
+    torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
 else:
     torch_device = None
diff --git a/transformers/tokenization_auto.py b/transformers/tokenization_auto.py
index de877af5f32..173dee0e2b2 100644
--- a/transformers/tokenization_auto.py
+++ b/transformers/tokenization_auto.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 
 from .tokenization_bert import BertTokenizer
+from .tokenization_bert_japanese import BertJapaneseTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_ctrl import CTRLTokenizer
@@ -75,6 +76,7 @@ class AutoTokenizer(object):
             - contains `albert`: AlbertTokenizer (ALBERT model)
             - contains `camembert`: CamembertTokenizer (CamemBERT model)
             - contains `roberta`: RobertaTokenizer (RoBERTa model)
+            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
             - contains `bert`: BertTokenizer (Bert model)
             - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
             - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
@@ -87,6 +89,7 @@ class AutoTokenizer(object):
             pretrained_model_name_or_path: either:
 
                 - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
+                - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                 - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                 - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
@@ -109,8 +112,14 @@ class AutoTokenizer(object):
 
         Examples::
 
-            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')    # Download vocabulary from S3 and cache.
-            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')  # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
+            # Download vocabulary from S3 and cache.
+            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+            # Download vocabulary from S3 (user-uploaded) and cache.
+            tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
+
+            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
+            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
 
         """
         if 't5' in pretrained_model_name_or_path:
@@ -123,6 +132,8 @@ class AutoTokenizer(object):
             return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif 'bert-base-japanese' in pretrained_model_name_or_path:
+            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         elif 'bert' in pretrained_model_name_or_path:
             return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         elif 'openai-gpt' in pretrained_model_name_or_path:
diff --git a/transformers/tokenization_bert_japanese.py b/transformers/tokenization_bert_japanese.py
new file mode 100644
index 00000000000..0ff45cbfe71
--- /dev/null
+++ b/transformers/tokenization_bert_japanese.py
@@ -0,0 +1,253 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Japanese text."""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import collections
+import logging
+import os
+import six
+import unicodedata
+from io import open
+
+from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer, load_vocab
+from .tokenization_utils import PreTrainedTokenizer
+
+logger = logging.getLogger(__name__)
+
+VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-vocab.txt",
+        'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-vocab.txt",
+        'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-vocab.txt",
+        'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-vocab.txt"
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    'bert-base-japanese': 512,
+    'bert-base-japanese-whole-word-masking': 512,
+    'bert-base-japanese-char': 512,
+    'bert-base-japanese-char-whole-word-masking': 512
+}
+
+PRETRAINED_INIT_CONFIGURATION = {
+    'bert-base-japanese': {
+        'do_lower_case': False,
+        'word_tokenizer_type': 'mecab',
+        'subword_tokenizer_type': 'wordpiece'
+    },
+    'bert-base-japanese-whole-word-masking': {
+        'do_lower_case': False,
+        'word_tokenizer_type': 'mecab',
+        'subword_tokenizer_type': 'wordpiece'
+    },
+    'bert-base-japanese-char': {
+        'do_lower_case': False,
+        'word_tokenizer_type': 'mecab',
+        'subword_tokenizer_type': 'character'
+    },
+    'bert-base-japanese-char-whole-word-masking': {
+        'do_lower_case': False,
+        'word_tokenizer_type': 'mecab',
+        'subword_tokenizer_type': 'character'
+    }
+}
+
+
+class BertJapaneseTokenizer(BertTokenizer):
+    """BERT tokenizer for Japanese text."""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+    def __init__(self, vocab_file, do_lower_case=False,
+                 do_word_tokenize=True, do_subword_tokenize=True,
+                 word_tokenizer_type='basic', subword_tokenizer_type='wordpiece',
+                 never_split=None, unk_token='[UNK]', sep_token='[SEP]',
+                 pad_token='[PAD]', cls_token='[CLS]', mask_token='[MASK]', **kwargs):
+        """Constructs a BertJapaneseTokenizer.
+
+        Args:
+            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file.
+            **do_lower_case**: (`optional`) boolean (default False)
+                Whether to lower case the input.
+                Only has an effect when do_word_tokenize=True.
+            **do_word_tokenize**: (`optional`) boolean (default True)
+                Whether to do word tokenization.
+            **do_subword_tokenize**: (`optional`) boolean (default True)
+                Whether to do subword tokenization.
+            **word_tokenizer_type**: (`optional`) string (default "basic")
+                Type of word tokenizer: "basic" or "mecab".
+            **subword_tokenizer_type**: (`optional`) string (default "wordpiece")
+                Type of subword tokenizer: "wordpiece" or "character".
+        """
+        # super(BertTokenizer, self) deliberately skips BertTokenizer.__init__ and calls
+        # PreTrainedTokenizer.__init__: the vocabulary and tokenizers are set up below.
+        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
+                                            pad_token=pad_token, cls_token=cls_token,
+                                            mask_token=mask_token, **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
+
+        if not os.path.isfile(vocab_file):
+            raise ValueError(
+                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a pretrained "
+                "model use `tokenizer = BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+        self.vocab = load_vocab(vocab_file)
+        self.ids_to_tokens = collections.OrderedDict(
+            [(ids, tok) for tok, ids in self.vocab.items()])
+
+        self.do_word_tokenize = do_word_tokenize
+        if do_word_tokenize:
+            if word_tokenizer_type == 'basic':
+                self.word_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+                                                     never_split=never_split,
+                                                     tokenize_chinese_chars=False)
+            elif word_tokenizer_type == 'mecab':
+                self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case,
+                                                     never_split=never_split)
+            else:
+                raise ValueError(
+                    "Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
+
+        self.do_subword_tokenize = do_subword_tokenize
+        if do_subword_tokenize:
+            if subword_tokenizer_type == 'wordpiece':
+                self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab,
+                                                            unk_token=self.unk_token)
+            elif subword_tokenizer_type == 'character':
+                self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab,
+                                                            unk_token=self.unk_token)
+            else:
+                raise ValueError(
+                    "Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))
+
+    def _tokenize(self, text):
+        if self.do_word_tokenize:
+            tokens = self.word_tokenizer.tokenize(text,
+                                                  never_split=self.all_special_tokens)
+        else:
+            tokens = [text]
+
+        if self.do_subword_tokenize:
+            split_tokens = [sub_token for token in tokens
+                            for sub_token in self.subword_tokenizer.tokenize(token)]
+        else:
+            split_tokens = tokens
+
+        return split_tokens
+
+
+class MecabTokenizer(object):
+    """Runs basic tokenization with MeCab morphological parser."""
+
+    def __init__(self, do_lower_case=False, never_split=None, normalize_text=True):
"""Constructs a MecabTokenizer. + + Args: + **do_lower_case**: (`optional`) boolean (default True) + Whether to lower case the input. + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. + Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) + List of token not to split. + **normalize_text**: (`optional`) boolean (default True) + Whether to apply unicode normalization to text before tokenization. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split if never_split is not None else [] + self.normalize_text = normalize_text + + import MeCab + self.mecab = MeCab.Tagger() + + def tokenize(self, text, never_split=None, **kwargs): + """Tokenizes a piece of text.""" + if self.normalize_text: + text = unicodedata.normalize('NFKC', text) + + never_split = self.never_split + (never_split if never_split is not None else []) + tokens = [] + + if six.PY2: + mecab_output = self.mecab.parse(text.encode('utf-8')).decode('utf-8') + else: + mecab_output = self.mecab.parse(text) + + cursor = 0 + for line in mecab_output.split('\n'): + if line == 'EOS': + break + + token, _ = line.split('\t') + token_start = text.index(token, cursor) + token_end = token_start + len(token) + if self.do_lower_case and token not in never_split: + token = token.lower() + + tokens.append(token) + cursor = token_end + + return tokens + + +class CharacterTokenizer(object): + """Runs Character tokenziation.""" + + def __init__(self, vocab, unk_token, normalize_text=True): + """Constructs a CharacterTokenizer. + + Args: + **vocab**: + Vocabulary object. + **unk_token**: str + A special symbol for out-of-vocabulary token. + **normalize_text**: (`optional`) boolean (default True) + Whether to apply unicode normalization to text before tokenization. + """ + self.vocab = vocab + self.unk_token = unk_token + self.normalize_text = normalize_text + + def tokenize(self, text): + """Tokenizes a piece of text into characters. + + For example: + input = "apple" + output = ["a", "p", "p", "l", "e"] + Args: + text: A single token or whitespace separated tokens. + This should have already been passed through `BasicTokenizer`. + Returns: + A list of characters. + """ + if self.normalize_text: + text = unicodedata.normalize('NFKC', text) + + output_tokens = [] + for i, char in enumerate(text): + if char not in self.vocab: + output_tokens.append(self.unk_token) + continue + + output_tokens.append(char) + + return output_tokens diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index f4395cd82cd..317ecd167b7 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -25,7 +25,7 @@ import itertools import re from io import open -from .file_utils import cached_path, is_tf_available, is_torch_available +from .file_utils import cached_path, is_remote_url, hf_bucket_url, is_tf_available, is_torch_available if is_tf_available(): import tensorflow as tf @@ -226,7 +226,7 @@ class PreTrainedTokenizer(object): self.max_len = max_len if max_len is not None else int(1e12) - # Padding side is right by default and over-riden in subclsses. If specified in the kwargs, it is changed. + # Padding side is right by default and over-riden in subclasses. If specified in the kwargs, it is changed. 
         self.padding_side = kwargs.pop('padding_side', self.padding_side)
 
         # Added tokens
@@ -255,6 +255,7 @@ class PreTrainedTokenizer(object):
             pretrained_model_name_or_path: either:
 
                 - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
+                - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                 - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                 - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
@@ -282,6 +283,9 @@ class PreTrainedTokenizer(object):
             # Download vocabulary from S3 and cache.
             tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
+            # Download vocabulary from S3 (user-uploaded) and cache.
+            tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
+
             # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
             tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
 
@@ -327,12 +331,15 @@ class PreTrainedTokenizer(object):
                 if os.path.isdir(pretrained_model_name_or_path):
                     # If a directory is provided we look for the standard filenames
                     full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
-                else:
+                    if not os.path.exists(full_file_name):
+                        logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
+                        full_file_name = None
+                elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                     # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
                     full_file_name = pretrained_model_name_or_path
-                if not os.path.exists(full_file_name):
-                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
-                    full_file_name = None
+                else:
+                    full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
+
                 vocab_files[file_id] = full_file_name
 
         # Look for the additional tokens files
@@ -628,7 +635,6 @@ class PreTrainedTokenizer(object):
             Take care of added tokens.
 
             text: The sequence to be encoded.
-            return_tokens_mapped_to_origin: (optional) Set to True to return the index of each token in the initial whitespace tokenization. (default False).
             **kwargs: passed to the child `self.tokenize()` method
         """
         def lowercase_text(t):
@@ -663,7 +669,7 @@ class PreTrainedTokenizer(object):
             return result
 
         def split_on_tokens(tok_list, text):
-            if not text:
+            if not text.strip():
                 return []
             if not tok_list:
                 return self._tokenize(text, **kwargs)
@@ -917,7 +923,7 @@ class PreTrainedTokenizer(object):
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers.
             return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
-            return_attention_mask: (optional) Set to False to avoir returning attention mask (default True)
+            return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
             return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
             return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
@@ -962,24 +968,13 @@ class PreTrainedTokenizer(object):
         if add_special_tokens:
             sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
             token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
-            special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
         else:
             sequence = ids + pair_ids if pair else ids
             token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
-            special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
+
+        if return_special_tokens_mask:
+            encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
 
-        # Prepare inputs as tensors if asked
-        if return_tensors == 'tf' and is_tf_available():
-            sequence = tf.constant([sequence])
-            token_type_ids = tf.constant([token_type_ids])
-        elif return_tensors == 'pt' and is_torch_available():
-            sequence = torch.tensor([sequence])
-            token_type_ids = torch.tensor([token_type_ids])
-        elif return_tensors is not None:
-            logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
-
         encoded_inputs["input_ids"] = sequence
         if return_token_type_ids:
             encoded_inputs["token_type_ids"] = token_type_ids
@@ -1003,7 +998,7 @@ class PreTrainedTokenizer(object):
             )
 
         if pad_to_max_length and max_length is None and self.max_len > 10000:
-            logger.warning("Sequence can't be padded as the maximum ")
+            logger.warning("Sequence can't be padded as no maximum length is specified and the model maximum length is too high.")
 
         if needs_to_be_padded:
             difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"])
@@ -1016,10 +1011,9 @@ class PreTrainedTokenizer(object):
                 if return_special_tokens_mask:
                     encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                 encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
-
             elif self.padding_side == 'left':
                 if return_attention_mask:
-                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"]) 
+                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
                 if return_token_type_ids:
                     encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"]
                 if return_special_tokens_mask:
@@ -1031,7 +1025,26 @@ class PreTrainedTokenizer(object):
 
         elif return_attention_mask:
             encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
-
+
+        # Prepare inputs as tensors if asked
+        if return_tensors == 'tf' and is_tf_available():
+            encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
+            encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
+
+            if "attention_mask" in encoded_inputs:
+                encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]])
+
+        elif return_tensors == 'pt' and is_torch_available():
+            encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
+            encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
+
+            if "attention_mask" in encoded_inputs:
+                encoded_inputs["attention_mask"] = torch.tensor([encoded_inputs["attention_mask"]])
+        elif return_tensors is not None:
+            logger.warning(
+                "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
+                    return_tensors))
+
         return encoded_inputs
 
     def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
diff --git a/transformers/tokenization_xlm.py b/transformers/tokenization_xlm.py
index 6c9f8e5e5c2..8def80bec49 100644
--- a/transformers/tokenization_xlm.py
+++ b/transformers/tokenization_xlm.py
@@ -549,6 +549,10 @@ class XLMTokenizer(PreTrainedTokenizer):
                          additional_special_tokens=additional_special_tokens,
                          **kwargs)
 
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
+
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
         # cache of sm.MosesTokenizer instance
diff --git a/utils/link_tester.py b/utils/link_tester.py
new file mode 100644
index 00000000000..fe3990d28c0
--- /dev/null
+++ b/utils/link_tester.py
@@ -0,0 +1,79 @@
+""" Link tester.
+
+This little utility reads all the python files in the repository,
+scans for links pointing to S3 and tests the links one by one. Exits
+with a non-zero status at the end of the scan if at least one link
+was reported broken.
+"""
+import os
+import re
+import sys
+
+import requests
+
+
+REGEXP_FIND_S3_LINKS = r"""([\"'])(https:\/\/s3)(.*)?\1"""
+
+
+def list_python_files_in_repository():
+    """ List all python files in the repository.
+
+    This function assumes that the script is executed in the root folder.
+    """
+    source_code_files = []
+    for path, subdirs, files in os.walk("."):
+        if "templates" in path:
+            continue
+        for name in files:
+            if name.endswith(".py"):
+                path_to_files = os.path.join(path, name)
+                source_code_files.append(path_to_files)
+
+    return source_code_files
+
+
+def find_all_links(file_paths):
+    links = []
+    for path in file_paths:
+        links += scan_code_for_links(path)
+
+    return links
+
+
+def scan_code_for_links(source):
+    """ Scans the file to find links using a regular expression.
+    Returns a list of links.
+    """
+    with open(source, 'r') as source_file:
+        content = source_file.read()
+    raw_links = re.findall(REGEXP_FIND_S3_LINKS, content)
+    links = [prefix + suffix for _, prefix, suffix in raw_links]
+
+    return links
+
+
+def check_all_links(links):
+    """ Check that the provided links are valid.
+
+    Links are considered valid if a HEAD request to the server
+    returns a 200 status code.
+    """
+    broken_links = []
+    for link in links:
+        head = requests.head(link)
+        if head.status_code != 200:
+            broken_links.append(link)
+
+    return broken_links
+
+
+if __name__ == "__main__":
+    print("Looking for broken links to pre-trained models/configs/tokenizers...")
+    file_paths = list_python_files_in_repository()
+    links = find_all_links(file_paths)
+    broken_links = check_all_links(links)
+    if broken_links:
+        print("The following links did not respond:")
+        for link in broken_links:
+            print("- {}".format(link))
+        sys.exit(1)
+    print("All links are ok.")
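
For reference, a minimal usage sketch of the new Japanese tokenizer. It assumes MeCab and a Python binding (``mecab-python3`` on Python 3, ``mecab-python`` on Python 2) are installed, as in the CI jobs added by this PR; the sample sentence and variable names are illustrative only, not part of the change::

    from transformers import AutoTokenizer

    # 'bert-base-japanese' is matched before the generic 'bert' branch in
    # AutoTokenizer, so this resolves to BertJapaneseTokenizer configured with
    # word_tokenizer_type='mecab' and subword_tokenizer_type='wordpiece'
    # via PRETRAINED_INIT_CONFIGURATION.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-japanese')

    # MeCab first segments the sentence into words, then WordPiece splits
    # out-of-vocabulary words into subword units from the downloaded vocab.txt.
    tokens = tokenizer.tokenize(u"吾輩は猫である。")
    ids = tokenizer.encode(u"吾輩は猫である。", add_special_tokens=True)

    # The '-char' checkpoints swap WordPiece for CharacterTokenizer: words are
    # still segmented with MeCab, then split into single characters.
    char_tokenizer = AutoTokenizer.from_pretrained('bert-base-japanese-char')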