mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

Merge pull request #489 from huggingface/tokenization_serialization

Better serialization for Tokenizers and Configuration classes - Also fix #466

commit 3d78e226e6
README.md

@@ -131,6 +131,7 @@ This package comprises the following classes that can be imported in Python and
 - Configuration classes for BERT, OpenAI GPT and Transformer-XL (in the respective [`modeling.py`](./pytorch_pretrained_bert/modeling.py), [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py), [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py) files):
   - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
   - `OpenAIGPTConfig` - Configuration class to store the configuration of an `OpenAIGPTModel` with utilities to read and write from JSON configuration files.
+  - `GPT2Config` - Configuration class to store the configuration of a `GPT2Model` with utilities to read and write from JSON configuration files.
   - `TransfoXLConfig` - Configuration class to store the configuration of a `TransfoXLModel` with utilities to read and write from JSON configuration files.

 The repository further comprises:
@@ -461,10 +462,12 @@ Here is a detailed documentation of the classes in the package and how to use th

 | Sub-section | Description |
 |-|-|
-| [Loading Google AI's/OpenAI's pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
-| [PyTorch models](#PyTorch-models) | API of the BERT, GPT, GPT-2 and Transformer-XL PyTorch model classes |
+| [Loading pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
+| [Serialization best-practices](#serialization-best-practices) | How to save and reload a fine-tuned model |
+| [Configurations](#configurations) | API of the configuration classes for BERT, GPT, GPT-2 and Transformer-XL |
+| [Models](#models) | API of the PyTorch model classes for BERT, GPT, GPT-2 and Transformer-XL |
 | [Tokenizers](#tokenizers) | API of the tokenizers class for BERT, GPT, GPT-2 and Transformer-XL |
-| [Optimizers](#optimizerss) | API of the optimizers |
+| [Optimizers](#optimizers) | API of the optimizers |

 ### Loading Google AI or OpenAI pre-trained weights or PyTorch dump
@@ -524,7 +527,101 @@ model = GPT2Model.from_pretrained('gpt2')
 ```

-### PyTorch models
+### Serialization best-practices
+
+This section explains how you can save and reload a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL).
+There are three types of files you need to save to be able to reload a fine-tuned model:
+
+- the model itself, which should be saved following PyTorch serialization [best practices](https://pytorch.org/docs/stable/notes/serialization.html#best-practices),
+- the configuration file of the model, which is saved as a JSON file, and
+- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
+
+Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
+```python
+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+
+output_dir = "./models/"
+
+# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
+
+# If we have a distributed model, save only the encapsulated model
+# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
+model_to_save = model.module if hasattr(model, 'module') else model
+
+# If we save using the predefined names, we can load using `from_pretrained`
+output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
+output_config_file = os.path.join(output_dir, CONFIG_NAME)
+
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(output_dir)
+
+# Step 2: Re-load the saved model and vocabulary
+
+# Example for a Bert model
+model = BertForQuestionAnswering.from_pretrained(output_dir)
+tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  # Add specific options if needed
+
+# Example for a GPT model
+model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
+tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
+```
+Here is another way you can save and reload the model if you want to use specific paths for each type of file:
+
+```python
+output_model_file = "./models/my_own_model_file.bin"
+output_config_file = "./models/my_own_config_file.bin"
+output_vocab_file = "./models/my_own_vocab_file.bin"
+
+# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
+
+# If we have a distributed model, save only the encapsulated model
+# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
+model_to_save = model.module if hasattr(model, 'module') else model
+
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(output_vocab_file)
+
+# Step 2: Re-load the saved model and vocabulary
+
+# Since we didn't save using the predefined WEIGHTS_NAME, CONFIG_NAME names,
+# we cannot load using `from_pretrained`. Here is how to do it in this situation:
+
+# Example for a Bert model
+config = BertConfig.from_json_file(output_config_file)
+model = BertForQuestionAnswering(config)
+state_dict = torch.load(output_model_file)
+model.load_state_dict(state_dict)
+tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
+
+# Example for a GPT model
+config = OpenAIGPTConfig.from_json_file(output_config_file)
+model = OpenAIGPTDoubleHeadsModel(config)
+state_dict = torch.load(output_model_file)
+model.load_state_dict(state_dict)
+tokenizer = OpenAIGPTTokenizer(output_vocab_file)
+```
+### Configurations
+
+Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and built from configuration classes, which contain the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
+
+- `BertConfig` for `BertModel` and BERT class instances.
+- `OpenAIGPTConfig` for `OpenAIGPTModel` and OpenAI GPT class instances.
+- `GPT2Config` for `GPT2Model` and OpenAI GPT-2 class instances.
+- `TransfoXLConfig` for `TransfoXLModel` and Transformer-XL class instances.
+
+These configuration classes contain a few utilities to load and save configurations:
+
+- `from_dict(cls, json_object)`: A class method to construct a configuration from a Python dictionary of parameters. Returns an instance of the configuration class.
+- `from_json_file(cls, json_file)`: A class method to construct a configuration from a JSON file of parameters. Returns an instance of the configuration class.
+- `to_dict()`: Serializes an instance to a Python dictionary. Returns a dictionary.
+- `to_json_string()`: Serializes an instance to a JSON string. Returns a string.
+- `to_json_file(json_file_path)`: Saves an instance to a JSON file.
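As a quick round-trip sketch of the utilities just listed (mirroring the `test_config_to_json_file` tests added later in this commit; the `/tmp/config.json` path is only an example):

```python
from pytorch_pretrained_bert import BertConfig

config_first = BertConfig(vocab_size_or_config_json_file=30522)

# Serialize to a JSON file and rebuild an equivalent configuration from it
config_first.to_json_file("/tmp/config.json")
config_second = BertConfig.from_json_file("/tmp/config.json")
assert config_second.to_dict() == config_first.to_dict()
```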
+### Models
+
 #### 1. `BertModel`

@@ -796,8 +893,7 @@ This model *outputs*:
 - `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
 - `presents`: a list of pre-computed hidden-states (key and values in each attention block) as torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).

-### Tokenizers:
+### Tokenizers
 #### `BertTokenizer`

@@ -816,6 +912,7 @@ and three methods:
 - `tokenize(text)`: convert a `str` into a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
 - `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens into a list of `int` indices in the vocabulary.
 - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices into a list of `str` tokens in the vocabulary.
+- `save_vocabulary(directory_path)`: save the vocabulary file to `directory_path`. Return the path to the saved vocabulary file: `vocab_file_path`. The vocabulary can be reloaded with `BertTokenizer.from_pretrained('vocab_file_path')` or `BertTokenizer.from_pretrained('directory_path')`.

 Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
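For concreteness, here is a minimal sketch of the three conversion methods above (it assumes the `bert-base-uncased` vocabulary is cached or can be downloaded):

```python
from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

tokens = tokenizer.tokenize("Who was Jim Henson?")     # basic + WordPiece tokenization
ids = tokenizer.convert_tokens_to_ids(tokens)          # list of str -> list of int
assert tokenizer.convert_ids_to_tokens(ids) == tokens  # and back again
```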
@@ -837,6 +934,7 @@ and five methods:
 - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices into a list of `str` tokens in the vocabulary.
 - `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
 - `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices into a string, with optional post-processing: (i) remove special tokens from the output and (ii) clean up tokenization spaces.
+- `save_vocabulary(directory_path)`: save the vocabulary, merges and special tokens files to `directory_path`. Return the paths to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.

 Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) for the details of the `OpenAIGPTTokenizer`.
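A short sketch of the encode/decode cycle described above (assuming the `openai-gpt` vocabulary is available locally or can be downloaded):

```python
from pytorch_pretrained_bert import OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))
# decode() turns indices back into a string; the two flags control the
# post-processing steps (i) and (ii) described in the bullet above
text = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```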
@@ -844,6 +942,8 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch

 `TransfoXLTokenizer` performs word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by token frequency. See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.
+
+The API is similar to the API of `BertTokenizer` (see above).

 Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.
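A hedged sketch of that counting workflow (the corpus path is hypothetical; see `tokenization_transfo_xl.py` for the exact signatures of `count_file` and `build_vocab`):

```python
from pytorch_pretrained_bert import TransfoXLTokenizer

# Build a frequency-ordered vocabulary from a raw text corpus
tokenizer = TransfoXLTokenizer(special=['<eos>'], lower_case=True)
tokenizer.count_file('corpus.txt')  # hypothetical file, one sentence per line
tokenizer.build_vocab()

ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))
```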
 #### `GPT2Tokenizer`

@@ -860,11 +960,11 @@ and two methods:

 - `encode(text)`: convert a `str` into a list of `int` tokens by performing byte-level BPE.
 - `decode(tokens)`: convert back a list of `int` tokens into a `str`.
+- `save_vocabulary(directory_path)`: save the vocabulary, merges and special tokens files to `directory_path`. Return the paths to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `GPT2Tokenizer.from_pretrained('directory_path')`.

 Please refer to [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) for more details on the `GPT2Tokenizer`.
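Since byte-level BPE is lossless, `encode` and `decode` form an exact round trip; a minimal sketch (assuming the `gpt2` vocabulary is available):

```python
from pytorch_pretrained_bert import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

ids = tokenizer.encode("Hello world!")  # str -> list of int (byte-level BPE)
assert tokenizer.decode(ids) == "Hello world!"
```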
-### Optimizers:
+### Optimizers

 #### `BertAdam`

@@ -1174,18 +1274,20 @@ To get these results we used a combination of:

 Here is the full list of hyper-parameters for this run:
 ```bash
+export SQUAD_DIR=/path/to/SQUAD
+
 python ./run_squad.py \
   --bert_model bert-large-uncased \
   --do_train \
   --do_predict \
   --do_lower_case \
-  --train_file $SQUAD_TRAIN \
-  --predict_file $SQUAD_EVAL \
+  --train_file $SQUAD_DIR/train-v1.1.json \
+  --predict_file $SQUAD_DIR/dev-v1.1.json \
   --learning_rate 3e-5 \
   --num_train_epochs 2 \
   --max_seq_length 384 \
   --doc_stride 128 \
-  --output_dir $OUTPUT_DIR \
+  --output_dir /tmp/debug_squad/ \
   --train_batch_size 24 \
   --gradient_accumulation_steps 2
 ```
@@ -1194,18 +1296,20 @@ If you have a recent GPU (starting from NVIDIA Volta series), you should try **1

 Here is an example of hyper-parameters for a FP16 run we tried:
 ```bash
+export SQUAD_DIR=/path/to/SQUAD
+
 python ./run_squad.py \
   --bert_model bert-large-uncased \
   --do_train \
   --do_predict \
   --do_lower_case \
-  --train_file $SQUAD_TRAIN \
-  --predict_file $SQUAD_EVAL \
+  --train_file $SQUAD_DIR/train-v1.1.json \
+  --predict_file $SQUAD_DIR/dev-v1.1.json \
   --learning_rate 3e-5 \
   --num_train_epochs 2 \
   --max_seq_length 384 \
   --doc_stride 128 \
-  --output_dir $OUTPUT_DIR \
+  --output_dir /tmp/debug_squad/ \
   --train_batch_size 24 \
   --fp16 \
   --loss_scale 128
examples/run_classifier.py

@@ -35,14 +35,11 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from scipy.stats import pearsonr, spearmanr
 from sklearn.metrics import matthews_corrcoef, f1_score

-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
 logger = logging.getLogger(__name__)

@@ -697,6 +694,11 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
+
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
     logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
         device, n_gpu, bool(args.local_rank != -1), args.fp16))
@@ -857,18 +859,21 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model and the associated configuration
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())

-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForSequenceClassification(config, num_labels=num_labels)
-        model.load_state_dict(torch.load(output_model_file))
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
+        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
     else:
         model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)
examples/run_openai_gpt.py

@@ -39,7 +39,8 @@ import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)

-from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
+from pytorch_pretrained_bert import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
+                                     OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME)

 ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

@@ -218,15 +219,20 @@ def main():

     # Save a trained model
     if args.do_train:
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
-        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-        config = model.config
-        torch.save(model_to_save.state_dict(), output_model_file)

-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = OpenAIGPTDoubleHeadsModel(config)
-        model.load_state_dict(model_state_dict)
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir)
+        tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
     model.to(device)

     if args.do_eval:
examples/run_squad.py

@@ -34,8 +34,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,

@@ -46,9 +46,6 @@ if sys.version_info[0] == 2:
 else:
     import pickle

-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
 logger = logging.getLogger(__name__)
@@ -837,7 +834,17 @@ def main():
     parser.add_argument('--null_score_diff_threshold',
                         type=float, default=0.0,
                         help="If null_score - best_non_null is greater than the threshold predict null.")
+    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
     args = parser.parse_args()
+    print(args)
+
+    if args.server_ip and args.server_port:
+        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+        import ptvsd
+        print("Waiting for debugger attach")
+        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+        ptvsd.wait_for_attach()

     if args.local_rank == -1 or args.no_cuda:
         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

@@ -848,6 +855,11 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
+
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
     logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
         device, n_gpu, bool(args.local_rank != -1), args.fp16))
@@ -983,7 +995,7 @@ def main():

         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
-            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                 if n_gpu == 1:
                     batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering itself
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
@@ -1008,19 +1020,21 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    if args.do_train:
-        # Save a trained model and the associated configuration
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())

-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForQuestionAnswering(config)
-        model.load_state_dict(torch.load(output_model_file))
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = BertForQuestionAnswering.from_pretrained(args.output_dir)
+        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
@@ -1054,7 +1068,7 @@ def main():
         model.eval()
         all_results = []
         logger.info("Start evaluating")
-        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
+        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
             if len(all_results) % 1000 == 0:
                 logger.info("Processing example: %d" % (len(all_results)))
             input_ids = input_ids.to(device)
examples/run_swag.py

@@ -32,8 +32,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME)
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import BertTokenizer

@@ -473,18 +473,20 @@ def main():

     if args.do_train:
-        # Save a trained model and the associated configuration
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())

-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForMultipleChoice(config, num_choices=4)
-        model.load_state_dict(torch.load(output_model_file))
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4)
+        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
     else:
         model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)
examples/run_transfo_xl.py

@@ -28,7 +28,7 @@ import math

 import torch

-from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
+from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',

@@ -80,6 +80,7 @@ def main():
     # The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
     # and tokenizing the dataset
     # The pre-processed corpus is a conversion (using the conversion script)
+    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
     corpus = TransfoXLCorpus.from_pretrained(args.model_name)
     ntokens = len(corpus.vocab)
pytorch_pretrained_bert/__init__.py

@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam

-from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
+from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME

pytorch_pretrained_bert/file_utils.py

@@ -33,6 +33,9 @@ except (AttributeError, ImportError):
     PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                               os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))

+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "pytorch_model.bin"
+
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
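Hoisting these two names into `file_utils.py` is what lets the README recipe and every example script build checkpoint paths the same way; a minimal sketch (`./models/` is a hypothetical directory):

```python
import os
from pytorch_pretrained_bert.file_utils import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./models/"
print(os.path.join(output_dir, WEIGHTS_NAME))  # ./models/pytorch_model.bin
print(os.path.join(output_dir, CONFIG_NAME))   # ./models/config.json
```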
pytorch_pretrained_bert/modeling.py

@@ -32,7 +32,7 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss

-from .file_utils import cached_path
+from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME

 logger = logging.getLogger(__name__)

@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
 }
-CONFIG_NAME = 'bert_config.json'
-WEIGHTS_NAME = 'pytorch_model.bin'
+BERT_CONFIG_NAME = 'bert_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'

 def load_tf_weights_in_bert(model, tf_checkpoint_path):

@@ -220,6 +219,11 @@ class BertConfig(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

+    def to_json_file(self, json_file_path):
+        """ Save this instance to a json file."""
+        with open(json_file_path, "w", encoding='utf-8') as writer:
+            writer.write(self.to_json_string())
+
 try:
     from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
 except ImportError:

@@ -581,13 +585,16 @@ class BertPreTrainedModel(nn.Module):
             serialization_dir = tempdir
         # Load config
         config_file = os.path.join(serialization_dir, CONFIG_NAME)
+        if not os.path.exists(config_file):
+            # Backward compatibility with old naming format
+            config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
         config = BertConfig.from_json_file(config_file)
         logger.info("Model config {}".format(config))
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
             weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
-            state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
+            state_dict = torch.load(weights_path, map_location='cpu')
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
pytorch_pretrained_bert/modeling_gpt2.py

@@ -34,7 +34,7 @@ import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter

-from .file_utils import cached_path
+from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
 from .modeling import BertLayerNorm as LayerNorm

 logger = logging.getLogger(__name__)

@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__)
 PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
 PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}

-CONFIG_NAME = "config.json"
-WEIGHTS_NAME = "pytorch_model.bin"
-
 def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
     """

@@ -180,6 +177,11 @@ class GPT2Config(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

+    def to_json_file(self, json_file_path):
+        """ Save this instance to a json file."""
+        with open(json_file_path, "w", encoding='utf-8') as writer:
+            writer.write(self.to_json_string())
+

 class Conv1D(nn.Module):
     def __init__(self, nf, nx):

@@ -416,7 +418,7 @@ class GPT2PreTrainedModel(nn.Module):
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
+            state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint (stored as NumPy array)
             return load_tf_weights_in_gpt2(model, resolved_archive_file)
pytorch_pretrained_bert/modeling_openai.py

@@ -34,7 +34,7 @@ import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter

-from .file_utils import cached_path
+from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
 from .modeling import BertLayerNorm as LayerNorm

 logger = logging.getLogger(__name__)

@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
 PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
 PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}

-CONFIG_NAME = "config.json"
-WEIGHTS_NAME = "pytorch_model.bin"

 def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
     """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)

@@ -225,6 +223,11 @@ class OpenAIGPTConfig(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

+    def to_json_file(self, json_file_path):
+        """ Save this instance to a json file."""
+        with open(json_file_path, "w", encoding='utf-8') as writer:
+            writer.write(self.to_json_string())
+

 class Conv1D(nn.Module):
     def __init__(self, nf, rf, nx):

@@ -473,7 +476,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
+            state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint (stored as NumPy array)
             return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter

 from .modeling import BertLayerNorm as LayerNorm
 from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
-from .file_utils import cached_path
+from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME

 logger = logging.getLogger(__name__)

@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
 PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
 }
-CONFIG_NAME = 'config.json'
-WEIGHTS_NAME = 'pytorch_model.bin'
+
 TF_WEIGHTS_NAME = 'model.ckpt'

 def build_tf_to_pytorch_map(model, config):

@@ -316,6 +315,11 @@ class TransfoXLConfig(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

+    def to_json_file(self, json_file_path):
+        """ Save this instance to a json file."""
+        with open(json_file_path, "w", encoding='utf-8') as writer:
+            writer.write(self.to_json_string())
+

 class PositionalEmbedding(nn.Module):
     def __init__(self, demb):

@@ -940,7 +944,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
+            state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint
             return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
pytorch_pretrained_bert/tokenization.py

@@ -134,6 +134,21 @@ class BertTokenizer(object):
             tokens.append(self.ids_to_tokens[i])
         return tokens

+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a directory or file."""
+        index = 0
+        if os.path.isdir(vocab_path):
+            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
+                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+        return vocab_file
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
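A short sketch of the save/reload cycle this new method enables (the output directory is hypothetical):

```python
import os
from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

save_dir = "/tmp/bert_vocab/"  # hypothetical directory
os.makedirs(save_dir, exist_ok=True)
vocab_file = tokenizer.save_vocabulary(save_dir)  # writes save_dir/vocab.txt

# Reload either from the directory or from the vocabulary file itself
reloaded = BertTokenizer.from_pretrained(save_dir)
```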
pytorch_pretrained_bert/tokenization_gpt2.py

@@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 @lru_cache()
 def bytes_to_unicode():

@@ -97,6 +98,11 @@ class GPT2Tokenizer(object):
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)

@@ -125,7 +131,11 @@ class GPT2Tokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):

@@ -187,6 +197,35 @@ class GPT2Tokenizer(object):
         self.cache[token] = word
         return word

+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
+                writer.write(token + u'\n')
+
+        return vocab_file, merge_file, special_tokens_file
+
     def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
pytorch_pretrained_bert/tokenization_openai.py

@@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 def get_pairs(word):
     """

@@ -86,9 +87,15 @@ class OpenAIGPTTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
             merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+            special_tokens_file = None
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)

@@ -117,7 +124,11 @@ class OpenAIGPTTokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):

@@ -261,3 +272,32 @@ class OpenAIGPTTokenizer(object):
                              ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
                              ).replace(" 've", "'ve")
         return out_string
+
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
+                writer.write(token + u'\n')
+
+        return vocab_file, merge_file, special_tokens_file
pytorch_pretrained_bert/tokenization_transfo_xl.py

@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            if os.path.isdir(pretrained_model_name_or_path):
+                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            else:
+                vocab_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)

@@ -141,6 +144,14 @@ class TransfoXLTokenizer(object):
         else:
             raise ValueError('No <unkown> token in vocabulary')

+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a directory or file."""
+        index = 0
+        if os.path.isdir(vocab_path):
+            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        torch.save(self.__dict__, vocab_file)
+        return vocab_file
+
     def build_vocab(self):
         if self.vocab_file:
             print('building vocab from {}'.format(self.vocab_file))
@@ -245,82 +256,24 @@ class TransfoXLTokenizer(object):
     def __len__(self):
         return len(self.idx2sym)

-    def _run_split_on_punc(self, text):
-        """Splits punctuation on a piece of text."""
-        if text in self.never_split:
-            return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
-        output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
-
-        return ["".join(x) for x in output]
-
-    def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
-        output = []
-        for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
-                continue
-            output.append(char)
-        return "".join(output)
-
-    def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if cp == 0 or cp == 0xfffd or _is_control(char):
-                continue
-            if _is_whitespace(char):
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-    def whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
-        text = text.strip()
-        if not text:
-            return []
-        if self.delimiter == '':
-            tokens = text
-        else:
-            tokens = text.split(self.delimiter)
-        return tokens
-
     def tokenize(self, line, add_eos=False, add_double_eos=False):
-        line = self._clean_text(line)
         line = line.strip()
+        # convert to lower case
+        if self.lower_case:
+            line = line.lower()

-        symbols = self.whitespace_tokenize(line)
-
-        split_symbols = []
-        for symbol in symbols:
-            if self.lower_case and symbol not in self.never_split:
-                symbol = symbol.lower()
-                symbol = self._run_strip_accents(symbol)
-            split_symbols.extend(self._run_split_on_punc(symbol))
+        # empty delimiter '' will evaluate False
+        if self.delimiter == '':
+            symbols = line
+        else:
+            symbols = line.split(self.delimiter)

         if add_double_eos: # lm1b
-            return ['<S>'] + split_symbols + ['<S>']
+            return ['<S>'] + symbols + ['<S>']
         elif add_eos:
-            return split_symbols + ['<eos>']
+            return symbols + ['<eos>']
         else:
-            return split_symbols
+            return symbols


 class LMOrderedIterator(object):
@@ -631,42 +584,3 @@ def get_lm_corpus(datadir, dataset):
         torch.save(corpus, fn)

     return corpus
-
-def _is_whitespace(char):
-    """Checks whether `chars` is a whitespace character."""
-    # \t, \n, and \r are technically control characters but we treat them
-    # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
-        return True
-    cat = unicodedata.category(char)
-    if cat == "Zs":
-        return True
-    return False
-
-
-def _is_control(char):
-    """Checks whether `chars` is a control character."""
-    # These are technically control characters but we count them as whitespace
-    # characters.
-    if char == "\t" or char == "\n" or char == "\r":
-        return False
-    cat = unicodedata.category(char)
-    if cat.startswith("C"):
-        return True
-    return False
-
-
-def _is_punctuation(char):
-    """Checks whether `chars` is a punctuation character."""
-    cp = ord(char)
-    # We treat all non-letter/number ASCII as punctuation.
-    # Characters such as "^", "$", and "`" are not in the Unicode
-    # Punctuation class but we treat them as punctuation anyways, for
-    # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-        return True
-    cat = unicodedata.category(char)
-    if cat.startswith("P"):
-        return True
-    return False
tests/modeling_gpt2_test.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random

@@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase):
         self.assertEqual(obj["vocab_size"], 99)
         self.assertEqual(obj["n_embd"], 37)

+    def test_config_to_json_file(self):
+        config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = GPT2Config.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_gpt2_model(*config_and_inputs)

tests/modeling_openai_test.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random

@@ -188,6 +189,14 @@ class OpenAIGPTModelTest(unittest.TestCase):
         self.assertEqual(obj["vocab_size"], 99)
         self.assertEqual(obj["n_embd"], 37)

+    def test_config_to_json_file(self):
+        config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = OpenAIGPTConfig.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_openai_model(*config_and_inputs)

tests/modeling_test.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random

@@ -251,6 +252,14 @@ class BertModelTest(unittest.TestCase):
         self.assertEqual(obj["vocab_size"], 99)
         self.assertEqual(obj["hidden_size"], 37)

+    def test_config_to_json_file(self):
+        config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = BertConfig.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_bert_model(*config_and_inputs)

tests/modeling_transfo_xl_test.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random

@@ -186,6 +187,14 @@ class TransfoXLModelTest(unittest.TestCase):
         self.assertEqual(obj["n_token"], 96)
         self.assertEqual(obj["d_embed"], 37)

+    def test_config_to_json_file(self):
+        config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = TransfoXLConfig.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
tests/tokenization_openai_test.py

@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+

 if __name__ == '__main__':
     unittest.main()

tests/tokenization_test.py

@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
tests/tokenization_transfo_xl_test.py

@@ -18,9 +18,7 @@ import os
 import unittest
 from io import open

-from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
-                                                             _is_control, _is_punctuation,
-                                                             _is_whitespace)
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer


 class TransfoXLTokenizationTest(unittest.TestCase):

@@ -37,54 +35,37 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.build_vocab()
         os.remove(vocab_file)

-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+
     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)

         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
+            tokenizer.tokenize(u" \tHeLLo ! how  \n Are yoU ?  "),
             ["hello", "!", "how", "are", "you", "?"])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

     def test_full_tokenizer_no_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=False)

         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
+            tokenizer.tokenize(u" \tHeLLo ! how  \n Are yoU ?  "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])

-    def test_is_whitespace(self):
-        self.assertTrue(_is_whitespace(u" "))
-        self.assertTrue(_is_whitespace(u"\t"))
-        self.assertTrue(_is_whitespace(u"\r"))
-        self.assertTrue(_is_whitespace(u"\n"))
-        self.assertTrue(_is_whitespace(u"\u00A0"))
-
-        self.assertFalse(_is_whitespace(u"A"))
-        self.assertFalse(_is_whitespace(u"-"))
-
-    def test_is_control(self):
-        self.assertTrue(_is_control(u"\u0005"))
-
-        self.assertFalse(_is_control(u"A"))
-        self.assertFalse(_is_control(u" "))
-        self.assertFalse(_is_control(u"\t"))
-        self.assertFalse(_is_control(u"\r"))
-
-    def test_is_punctuation(self):
-        self.assertTrue(_is_punctuation(u"-"))
-        self.assertTrue(_is_punctuation(u"$"))
-        self.assertTrue(_is_punctuation(u"`"))
-        self.assertTrue(_is_punctuation(u"."))
-
-        self.assertFalse(_is_punctuation(u"A"))
-        self.assertFalse(_is_punctuation(u" "))
-

 if __name__ == '__main__':
     unittest.main()