diff --git a/README.md b/README.md
index daac69de9f1..caf415508f9 100644
--- a/README.md
+++ b/README.md
@@ -131,6 +131,7 @@ This package comprises the following classes that can be imported in Python and
 - Configuration classes for BERT, OpenAI GPT and Transformer-XL (in the respective [`modeling.py`](./pytorch_pretrained_bert/modeling.py), [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py), [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py) files):
   - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
   - `OpenAIGPTConfig` - Configuration class to store the configuration of a `OpenAIGPTModel` with utilities to read and write from JSON configuration files.
+  - `GPT2Config` - Configuration class to store the configuration of a `GPT2Model` with utilities to read and write from JSON configuration files.
   - `TransfoXLConfig` - Configuration class to store the configuration of a `TransfoXLModel` with utilities to read and write from JSON configuration files.
 
 The repository further comprises:
@@ -461,10 +462,12 @@ Here is a detailed documentation of the classes in the package and how to use th
 
 | Sub-section | Description |
 |-|-|
-| [Loading Google AI's/OpenAI's pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
-| [PyTorch models](#PyTorch-models) | API of the BERT, GPT, GPT-2 and Transformer-XL PyTorch model classes |
+| [Loading pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weights or a PyTorch saved instance |
+| [Serialization best-practices](#serialization-best-practices) | How to save and reload a fine-tuned model |
+| [Configurations](#configurations) | API of the configuration classes for BERT, GPT, GPT-2 and Transformer-XL |
+| [Models](#models) | API of the PyTorch model classes for BERT, GPT, GPT-2 and Transformer-XL |
 | [Tokenizers](#tokenizers) | API of the tokenizers class for BERT, GPT, GPT-2 and Transformer-XL|
-| [Optimizers](#optimizerss) | API of the optimizers |
+| [Optimizers](#optimizers) | API of the optimizers |
 
 ### Loading Google AI or OpenAI pre-trained weights or PyTorch dump
@@ -524,7 +527,101 @@ model = GPT2Model.from_pretrained('gpt2')
 ```
 
-### PyTorch models
+### Serialization best-practices
+
+This section explains how you can save and reload a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL).
+There are three types of files you need to save to be able to reload a fine-tuned model:
+
+- the model itself, which should be saved following PyTorch serialization [best practices](https://pytorch.org/docs/stable/notes/serialization.html#best-practices),
+- the configuration file of the model, which is saved as a JSON file, and
+- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
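+
+For reference, the default file names that `from_pretrained` looks for are exposed as constants in the package (a minimal sketch; the vocabulary file names depend on the tokenizer, e.g. `vocab.txt` for BERT and `vocab.json` plus `merges.txt` for the BPE-based GPT/GPT-2 tokenizers):
+
+```python
+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+
+# Default file names used when saving/loading with `from_pretrained`
+print(WEIGHTS_NAME)  # 'pytorch_model.bin' - the PyTorch state dict of the model
+print(CONFIG_NAME)   # 'config.json' - the JSON configuration of the model
+```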
+
+Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
+
+```python
+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+
+output_dir = "./models/"
+
+# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
+
+# If we have a distributed model, save only the encapsulated model
+# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
+model_to_save = model.module if hasattr(model, 'module') else model
+
+# If we save using the predefined names, we can load using `from_pretrained`
+output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
+output_config_file = os.path.join(output_dir, CONFIG_NAME)
+
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(output_dir)
+
+# Step 2: Re-load the saved model and vocabulary
+
+# Example for a Bert model
+model = BertForQuestionAnswering.from_pretrained(output_dir)
+tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  # Add specific options if needed
+# Example for a GPT model
+model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
+tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
+```
+
+Here is another way you can save and reload the model if you want to use specific paths for each type of file:
+
+```python
+output_model_file = "./models/my_own_model_file.bin"
+output_config_file = "./models/my_own_config_file.bin"
+output_vocab_file = "./models/my_own_vocab_file.bin"
+
+# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
+
+# If we have a distributed model, save only the encapsulated model
+# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
+model_to_save = model.module if hasattr(model, 'module') else model
+
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(output_vocab_file)
+
+# Step 2: Re-load the saved model and vocabulary
+
+# We didn't save using the predefined WEIGHTS_NAME, CONFIG_NAME names, so we cannot load using `from_pretrained`.
+# Here is how to do it in this situation:
+
+# Example for a Bert model
+config = BertConfig.from_json_file(output_config_file)
+model = BertForQuestionAnswering(config)
+state_dict = torch.load(output_model_file)
+model.load_state_dict(state_dict)
+tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
+
+# Example for a GPT model
+config = OpenAIGPTConfig.from_json_file(output_config_file)
+model = OpenAIGPTDoubleHeadsModel(config)
+state_dict = torch.load(output_model_file)
+model.load_state_dict(state_dict)
+tokenizer = OpenAIGPTTokenizer(output_vocab_file)
+```
+
+### Configurations
+
+Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and built from configuration classes which contain the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
+
+- `BertConfig` for `BertModel` and BERT class instances.
+- `OpenAIGPTConfig` for `OpenAIGPTModel` and OpenAI GPT class instances.
+- `GPT2Config` for `GPT2Model` and OpenAI GPT-2 class instances.
+- `TransfoXLConfig` for `TransfoXLModel` and Transformer-XL class instances.
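+
+For example, a configuration can be serialized to JSON and reloaded with the utilities listed below (a minimal sketch; the constructor arguments and values shown are only illustrative):
+
+```python
+from pytorch_pretrained_bert import GPT2Config
+
+# Build a configuration from keyword arguments (hypothetical example values)
+config = GPT2Config(vocab_size_or_config_json_file=50257, n_embd=768)
+
+# Write it to disk and read it back
+config.to_json_file('/tmp/gpt2_config.json')
+config_reloaded = GPT2Config.from_json_file('/tmp/gpt2_config.json')
+assert config.to_dict() == config_reloaded.to_dict()
+```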
+
+These configuration classes contain a few utilities to load and save configurations:
+
+- `from_dict(cls, json_object)`: A class method to construct a configuration from a Python dictionary of parameters. Returns an instance of the configuration class.
+- `from_json_file(cls, json_file)`: A class method to construct a configuration from a JSON file of parameters. Returns an instance of the configuration class.
+- `to_dict()`: Serializes an instance to a Python dictionary. Returns a dictionary.
+- `to_json_string()`: Serializes an instance to a JSON string. Returns a string.
+- `to_json_file(json_file_path)`: Saves an instance to a JSON file.
+
+### Models
 
 #### 1. `BertModel`
@@ -796,8 +893,7 @@ This model *outputs*:
 - `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
 - `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as a torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
-
-### Tokenizers:
+### Tokenizers
 
 #### `BertTokenizer`
@@ -816,6 +912,7 @@ and three methods:
 - `tokenize(text)`: convert a `str` in a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
 - `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
 - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
+- `save_vocabulary(directory_path)`: save the vocabulary file to `directory_path`. Return the path to the saved vocabulary file: `vocab_file_path`. The vocabulary can be reloaded with `BertTokenizer.from_pretrained('vocab_file_path')` or `BertTokenizer.from_pretrained('directory_path')`.
 
 Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
@@ -837,6 +934,7 @@ and five methods:
 - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
 - `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
 - `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices in a string and do some post-processing if needed: (i) remove special tokens from the output and (ii) clean up tokenization spaces.
+- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.
 
 Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) for the details of the `OpenAIGPTTokenizer`.
@@ -844,6 +942,8 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch
 
 `TransfoXLTokenizer` perform word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by toekn frequency (for adaptive softmax). See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.
+The API is similar to that of `BertTokenizer` (see above).
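+
+As with the other tokenizers, the vocabulary can be saved and reloaded (a minimal sketch; `./models/` is just an example directory and the name of the vocabulary file written inside it is chosen by the library):
+
+```python
+from pytorch_pretrained_bert import TransfoXLTokenizer
+
+tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
+
+# `save_vocabulary` expects an existing directory and returns the path of the file it wrote
+vocab_file = tokenizer.save_vocabulary('./models/')
+
+# The saved vocabulary can be reloaded from the directory or from the returned file path
+tokenizer = TransfoXLTokenizer.from_pretrained('./models/')
+```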
+ Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`. #### `GPT2Tokenizer` @@ -860,11 +960,11 @@ and two methods: - `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE. - `decode(tokens)`: convert back a list of `int` tokens in a `str`. +- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`. Please refer to [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) for more details on the `GPT2Tokenizer`. - -### Optimizers: +### Optimizers #### `BertAdam` @@ -1174,18 +1274,20 @@ To get these results we used a combination of: Here is the full list of hyper-parameters for this run: ```bash +export SQUAD_DIR=/path/to/SQUAD + python ./run_squad.py \ --bert_model bert-large-uncased \ --do_train \ --do_predict \ --do_lower_case \ - --train_file $SQUAD_TRAIN \ - --predict_file $SQUAD_EVAL \ + --train_file $SQUAD_DIR/train-v1.1.json \ + --predict_file $SQUAD_DIR/dev-v1.1.json \ --learning_rate 3e-5 \ --num_train_epochs 2 \ --max_seq_length 384 \ --doc_stride 128 \ - --output_dir $OUTPUT_DIR \ + --output_dir /tmp/debug_squad/ \ --train_batch_size 24 \ --gradient_accumulation_steps 2 ``` @@ -1194,18 +1296,20 @@ If you have a recent GPU (starting from NVIDIA Volta series), you should try **1 Here is an example of hyper-parameters for a FP16 run we tried: ```bash +export SQUAD_DIR=/path/to/SQUAD + python ./run_squad.py \ --bert_model bert-large-uncased \ --do_train \ --do_predict \ --do_lower_case \ - --train_file $SQUAD_TRAIN \ - --predict_file $SQUAD_EVAL \ + --train_file $SQUAD_DIR/train-v1.1.json \ + --predict_file $SQUAD_DIR/dev-v1.1.json \ --learning_rate 3e-5 \ --num_train_epochs 2 \ --max_seq_length 384 \ --doc_stride 128 \ - --output_dir $OUTPUT_DIR \ + --output_dir /tmp/debug_squad/ \ --train_batch_size 24 \ --fp16 \ --loss_scale 128 diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 4268c41ec60..b90ac494e4f 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -35,14 +35,11 @@ from torch.nn import CrossEntropyLoss, MSELoss from scipy.stats import pearsonr, spearmanr from sklearn.metrics import matthews_corrcoef, f1_score -from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME +from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME +from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt = '%m/%d/%Y %H:%M:%S', - level = logging.INFO) logger = logging.getLogger(__name__) @@ -697,6 +694,11 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') + + logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + 
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16)) @@ -857,18 +859,21 @@ def main(): optimizer.zero_grad() global_step += 1 - # Save a trained model and the associated configuration + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) - torch.save(model_to_save.state_dict(), output_model_file) - output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - with open(output_config_file, 'w') as f: - f.write(model_to_save.config.to_json_string()) - # Load a trained model and config that you have fine-tuned - config = BertConfig(output_config_file) - model = BertForSequenceClassification(config, num_labels=num_labels) - model.load_state_dict(torch.load(output_model_file)) + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(args.output_dir) + + # Load a trained model and vocabulary that you have fine-tuned + model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels) + tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) else: model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) model.to(device) diff --git a/examples/run_openai_gpt.py b/examples/run_openai_gpt.py index ee30a7a0a4c..cb5aa8d9cbd 100644 --- a/examples/run_openai_gpt.py +++ b/examples/run_openai_gpt.py @@ -39,7 +39,8 @@ import torch from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) -from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path +from pytorch_pretrained_bert import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, + OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME) ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz" @@ -218,15 +219,20 @@ def main(): # Save a trained model if args.do_train: + # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") - config = model.config - torch.save(model_to_save.state_dict(), output_model_file) - # Load a trained model that you have fine-tuned - model_state_dict = torch.load(output_model_file) - model = OpenAIGPTDoubleHeadsModel(config) - model.load_state_dict(model_state_dict) + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(args.output_dir) + + # Load a trained model and vocabulary that you have fine-tuned + model = 
OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir) + tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir) model.to(device) if args.do_eval: diff --git a/examples/run_squad.py b/examples/run_squad.py index 043b795326c..410fd852988 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -34,8 +34,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange -from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME +from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME +from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.tokenization import (BasicTokenizer, BertTokenizer, @@ -46,9 +46,6 @@ if sys.version_info[0] == 2: else: import pickle -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt = '%m/%d/%Y %H:%M:%S', - level = logging.INFO) logger = logging.getLogger(__name__) @@ -837,7 +834,17 @@ def main(): parser.add_argument('--null_score_diff_threshold', type=float, default=0.0, help="If null_score - best_non_null is greater than the threshold predict null.") + parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") + parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") args = parser.parse_args() + print(args) + + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") @@ -848,6 +855,11 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') + + logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16)) @@ -983,7 +995,7 @@ def main(): model.train() for _ in trange(int(args.num_train_epochs), desc="Epoch"): - for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): + for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): if n_gpu == 1: batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self input_ids, input_mask, segment_ids, start_positions, end_positions = batch @@ -1008,19 +1020,21 @@ def main(): optimizer.zero_grad() global_step += 1 - if args.do_train: - # Save a trained model and the associated configuration + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - 
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) - torch.save(model_to_save.state_dict(), output_model_file) - output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - with open(output_config_file, 'w') as f: - f.write(model_to_save.config.to_json_string()) - # Load a trained model and config that you have fine-tuned - config = BertConfig(output_config_file) - model = BertForQuestionAnswering(config) - model.load_state_dict(torch.load(output_model_file)) + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(args.output_dir) + + # Load a trained model and vocabulary that you have fine-tuned + model = BertForQuestionAnswering.from_pretrained(args.output_dir) + tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) else: model = BertForQuestionAnswering.from_pretrained(args.bert_model) @@ -1054,7 +1068,7 @@ def main(): model.eval() all_results = [] logger.info("Start evaluating") - for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): + for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): if len(all_results) % 1000 == 0: logger.info("Processing example: %d" % (len(all_results))) input_ids = input_ids.to(device) diff --git a/examples/run_swag.py b/examples/run_swag.py index f193582640c..a6cfdbe311d 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -32,8 +32,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange -from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME) +from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME +from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.tokenization import BertTokenizer @@ -473,18 +473,20 @@ def main(): if args.do_train: - # Save a trained model and the associated configuration + # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) - torch.save(model_to_save.state_dict(), output_model_file) - output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - with open(output_config_file, 'w') as f: - f.write(model_to_save.config.to_json_string()) - # Load a trained model and config that you have fine-tuned - config = BertConfig(output_config_file) - model = BertForMultipleChoice(config, num_choices=4) - model.load_state_dict(torch.load(output_model_file)) + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) 
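+        # Saving the vocabulary alongside the weights and config lets `from_pretrained`
+        # reload both the model and the tokenizer from args.output_dir below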
+ tokenizer.save_vocabulary(args.output_dir) + + # Load a trained model and vocabulary that you have fine-tuned + model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4) + tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) else: model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4) model.to(device) diff --git a/examples/run_transfo_xl.py b/examples/run_transfo_xl.py index 8139f28baf5..0ea7b320536 100644 --- a/examples/run_transfo_xl.py +++ b/examples/run_transfo_xl.py @@ -28,7 +28,7 @@ import math import torch -from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus +from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -80,6 +80,7 @@ def main(): # The pre-processing involve computing word frequencies to prepare the Adaptive input and SoftMax # and tokenizing the dataset # The pre-processed corpus is a convertion (using the conversion script ) + tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name) corpus = TransfoXLCorpus.from_pretrained(args.model_name) ntokens = len(corpus.vocab) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index bd455b8d9cb..28d215d8bdf 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model, from .optimization import BertAdam from .optimization_openai import OpenAIAdam -from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path +from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME diff --git a/pytorch_pretrained_bert/file_utils.py b/pytorch_pretrained_bert/file_utils.py index 8601edde231..6de7e259e5c 100644 --- a/pytorch_pretrained_bert/file_utils.py +++ b/pytorch_pretrained_bert/file_utils.py @@ -33,6 +33,9 @@ except (AttributeError, ImportError): PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" + logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 037c6e97231..55c6d9d6116 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -32,7 +32,7 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss -from .file_utils import cached_path +from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME logger = logging.getLogger(__name__) @@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", } -CONFIG_NAME = 'bert_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' +BERT_CONFIG_NAME = 'bert_config.json' TF_WEIGHTS_NAME = 'model.ckpt' def load_tf_weights_in_bert(model, tf_checkpoint_path): @@ -220,6 +219,11 @@ class BertConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') 
as writer: + writer.write(self.to_json_string()) + try: from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm except ImportError: @@ -581,13 +585,16 @@ class BertPreTrainedModel(nn.Module): serialization_dir = tempdir # Load config config_file = os.path.join(serialization_dir, CONFIG_NAME) + if not os.path.exists(config_file): + # Backward compatibility with old naming format + config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) config = BertConfig.from_json_file(config_file) logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(weights_path, map_location='cpu') if tempdir: # Clean up temp dir shutil.rmtree(tempdir) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 7b00ce77300..7cf1e6b59de 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -34,7 +34,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter -from .file_utils import cached_path +from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME from .modeling import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) @@ -42,9 +42,6 @@ logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"} PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"} -CONFIG_NAME = "config.json" -WEIGHTS_NAME = "pytorch_model.bin" - def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): """ Load tf checkpoints in a pytorch model """ @@ -180,6 +177,11 @@ class GPT2Config(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + class Conv1D(nn.Module): def __init__(self, nf, nx): @@ -416,7 +418,7 @@ class GPT2PreTrainedModel(nn.Module): # Instantiate model. 
model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint (stored as NumPy array) return load_tf_weights_in_gpt2(model, resolved_archive_file) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 7b95d74f7c3..f956462ddbf 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -34,7 +34,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter -from .file_utils import cached_path +from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME from .modeling import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) @@ -42,8 +42,6 @@ logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"} PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"} -CONFIG_NAME = "config.json" -WEIGHTS_NAME = "pytorch_model.bin" def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path): """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here) @@ -225,6 +223,11 @@ class OpenAIGPTConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + class Conv1D(nn.Module): def __init__(self, nf, rf, nx): @@ -473,7 +476,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): # Instantiate model. 
model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint (stored as NumPy array) return load_tf_weights_in_openai_gpt(model, resolved_archive_file) diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index ac895a03a7e..e8fffc5b608 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter from .modeling import BertLayerNorm as LayerNorm from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits -from .file_utils import cached_path +from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME logger = logging.getLogger(__name__) @@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { PRETRAINED_CONFIG_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", } -CONFIG_NAME = 'config.json' -WEIGHTS_NAME = 'pytorch_model.bin' + TF_WEIGHTS_NAME = 'model.ckpt' def build_tf_to_pytorch_map(model, config): @@ -316,6 +315,11 @@ class TransfoXLConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) + class PositionalEmbedding(nn.Module): def __init__(self, demb): @@ -940,7 +944,7 @@ class TransfoXLPreTrainedModel(nn.Module): # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path) diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index bbb3e25fc79..3937d6e0118 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -134,6 +134,21 @@ class BertTokenizer(object): tokens.append(self.ids_to_tokens[i]) return tokens + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): """ diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py index db95719dbcc..ab80876ee51 100644 --- a/pytorch_pretrained_bert/tokenization_gpt2.py +++ b/pytorch_pretrained_bert/tokenization_gpt2.py @@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { } VOCAB_NAME = 'vocab.json' MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' @lru_cache() def bytes_to_unicode(): @@ -97,6 +98,11 @@ class GPT2Tokenizer(object): else: vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) # redirect to the cache, if necessary try: resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) @@ -125,7 +131,11 @@ class GPT2Tokenizer(object): max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. - tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) return tokenizer def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): @@ -187,6 +197,35 @@ class GPT2Tokenizer(object): self.cache[token] = word return word + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
+ " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]): + writer.write(token + u'\n') + + return vocab_file, merge_file, special_tokens_file + def encode(self, text): bpe_tokens = [] for token in re.findall(self.pat, text): diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index 240122d12df..7a10271175b 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { } VOCAB_NAME = 'vocab.json' MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' def get_pairs(word): """ @@ -86,9 +87,15 @@ class OpenAIGPTTokenizer(object): if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] + special_tokens_file = None else: vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) # redirect to the cache, if necessary try: resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) @@ -117,7 +124,11 @@ class OpenAIGPTTokenizer(object): max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. - tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) return tokenizer def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): @@ -261,3 +272,32 @@ class OpenAIGPTTokenizer(object): ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m " ).replace(" 've", "'ve") return out_string + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
+ " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]): + writer.write(token + u'\n') + + return vocab_file, merge_file, special_tokens_file diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py index b5360c51843..ddebc57c106 100644 --- a/pytorch_pretrained_bert/tokenization_transfo_xl.py +++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py @@ -63,7 +63,10 @@ class TransfoXLTokenizer(object): if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] else: - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + if os.path.isdir(pretrained_model_name_or_path): + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + else: + vocab_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) @@ -141,6 +144,14 @@ class TransfoXLTokenizer(object): else: raise ValueError('No token in vocabulary') + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + torch.save(self.__dict__, vocab_file) + return vocab_file + def build_vocab(self): if self.vocab_file: print('building vocab from {}'.format(self.vocab_file)) @@ -245,82 +256,24 @@ class TransfoXLTokenizer(object): def __len__(self): return len(self.idx2sym) - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - if text in self.never_split: - return [text] - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): - continue - if _is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - def whitespace_tokenize(self, text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - if self.delimiter == '': - tokens = text - else: - tokens = text.split(self.delimiter) - return tokens - def tokenize(self, line, add_eos=False, add_double_eos=False): - line = self._clean_text(line) line = line.strip() + # convert to lower case + if self.lower_case: + line = line.lower() - symbols = self.whitespace_tokenize(line) - - split_symbols = [] - for symbol in symbols: - if self.lower_case and symbol not in self.never_split: - symbol = symbol.lower() - symbol = self._run_strip_accents(symbol) - 
split_symbols.extend(self._run_split_on_punc(symbol)) + # empty delimiter '' will evaluate False + if self.delimiter == '': + symbols = line + else: + symbols = line.split(self.delimiter) if add_double_eos: # lm1b - return [''] + split_symbols + [''] + return [''] + symbols + [''] elif add_eos: - return split_symbols + [''] + return symbols + [''] else: - return split_symbols + return symbols class LMOrderedIterator(object): @@ -631,42 +584,3 @@ def get_lm_corpus(datadir, dataset): torch.save(corpus, fn) return corpus - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False diff --git a/tests/modeling_gpt2_test.py b/tests/modeling_gpt2_test.py index 12a539c44b6..d5424220600 100644 --- a/tests/modeling_gpt2_test.py +++ b/tests/modeling_gpt2_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase): self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["n_embd"], 37) + def test_config_to_json_file(self): + config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = GPT2Config.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() output_result = tester.create_gpt2_model(*config_and_inputs) diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py index 1cc8b7d5dcb..db03bf792e2 100644 --- a/tests/modeling_openai_test.py +++ b/tests/modeling_openai_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -188,6 +189,14 @@ class OpenAIGPTModelTest(unittest.TestCase): self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["n_embd"], 37) + def test_config_to_json_file(self): + config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = OpenAIGPTConfig.from_json_file(json_file_path) + 
os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() output_result = tester.create_openai_model(*config_and_inputs) diff --git a/tests/modeling_test.py b/tests/modeling_test.py index c7a031cfb04..02d7a13fdac 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -251,6 +252,14 @@ class BertModelTest(unittest.TestCase): self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["hidden_size"], 37) + def test_config_to_json_file(self): + config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = BertConfig.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() output_result = tester.create_bert_model(*config_and_inputs) diff --git a/tests/modeling_transfo_xl_test.py b/tests/modeling_transfo_xl_test.py index 291d5d9d2af..a59d90b2055 100644 --- a/tests/modeling_transfo_xl_test.py +++ b/tests/modeling_transfo_xl_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import unittest import json import random @@ -186,6 +187,14 @@ class TransfoXLModelTest(unittest.TestCase): self.assertEqual(obj["n_token"], 96) self.assertEqual(obj["d_embed"], 37) + def test_config_to_json_file(self): + config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37) + json_file_path = "/tmp/config.json" + config_first.to_json_file(json_file_path) + config_second = TransfoXLConfig.from_json_file(json_file_path) + os.remove(json_file_path) + self.assertEqual(config_second.to_dict(), config_first.to_dict()) + def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() diff --git a/tests/tokenization_openai_test.py b/tests/tokenization_openai_test.py index 6213eb1b037..1f695cfb12c 100644 --- a/tests/tokenization_openai_test.py +++ b/tests/tokenization_openai_test.py @@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer.from_pretrained("/tmp/") + os.remove(vocab_file) + os.remove(merges_file) + + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + if __name__ == '__main__': unittest.main() diff --git a/tests/tokenization_test.py b/tests/tokenization_test.py index 78e145ffd21..15cc7ccd820 100644 --- a/tests/tokenization_test.py +++ b/tests/tokenization_test.py @@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer.from_pretrained(vocab_file) + os.remove(vocab_file) + + tokens = 
tokenizer.tokenize(u"UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + + def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/tests/tokenization_transfo_xl_test.py b/tests/tokenization_transfo_xl_test.py index 9ff04f5f34d..1a805f11e6c 100644 --- a/tests/tokenization_transfo_xl_test.py +++ b/tests/tokenization_transfo_xl_test.py @@ -18,9 +18,7 @@ import os import unittest from io import open -from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer, - _is_control, _is_punctuation, - _is_whitespace) +from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer class TransfoXLTokenizationTest(unittest.TestCase): @@ -37,54 +35,37 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer.build_vocab() os.remove(vocab_file) - tokens = tokenizer.tokenize(u" UNwant\u00E9d,running") + tokens = tokenizer.tokenize(u" UNwanted , running") self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) self.assertListEqual( tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer.from_pretrained(vocab_file) + os.remove(vocab_file) + + tokens = tokenizer.tokenize(u" UNwanted , running") + self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + + def test_full_tokenizer_lower(self): tokenizer = TransfoXLTokenizer(lower_case=True) self.assertListEqual( - tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), + tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), ["hello", "!", "how", "are", "you", "?"]) - self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) def test_full_tokenizer_no_lower(self): tokenizer = TransfoXLTokenizer(lower_case=False) self.assertListEqual( - tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), + tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]) - def test_is_whitespace(self): - self.assertTrue(_is_whitespace(u" ")) - self.assertTrue(_is_whitespace(u"\t")) - self.assertTrue(_is_whitespace(u"\r")) - self.assertTrue(_is_whitespace(u"\n")) - self.assertTrue(_is_whitespace(u"\u00A0")) - - self.assertFalse(_is_whitespace(u"A")) - self.assertFalse(_is_whitespace(u"-")) - - def test_is_control(self): - self.assertTrue(_is_control(u"\u0005")) - - self.assertFalse(_is_control(u"A")) - self.assertFalse(_is_control(u" ")) - self.assertFalse(_is_control(u"\t")) - self.assertFalse(_is_control(u"\r")) - - def test_is_punctuation(self): - self.assertTrue(_is_punctuation(u"-")) - self.assertTrue(_is_punctuation(u"$")) - self.assertTrue(_is_punctuation(u"`")) - self.assertTrue(_is_punctuation(u".")) - - self.assertFalse(_is_punctuation(u"A")) - self.assertFalse(_is_punctuation(u" ")) - if __name__ == '__main__': unittest.main()