mirror of https://github.com/huggingface/transformers.git
commit e1bfad4846

README.md
@@ -19,7 +19,7 @@ This implementation is provided with [Google's pre-trained models](https://githu

## Installation

This repo was tested on Python 3.6+ and PyTorch 0.4.1
This repo was tested on Python 3.5+ and PyTorch 0.4.1/1.0.0

### With pip
@@ -46,13 +46,13 @@ python -m pytest -sv tests/

This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:

- Seven PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
- [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
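For orientation, a minimal hedged sketch (not part of this diff) of loading one of these classes; the `bert-base-uncased` shortcut is downloaded and cached on first use:

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

# Sketch: encode one sentence with the raw, fully pre-trained Transformer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

tokens = tokenizer.tokenize("[CLS] Who was Jim Henson ? [SEP]")
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

with torch.no_grad():
    # encoded_layers: one tensor per Transformer layer;
    # pooled_output: a fixed-size summary computed from the first ([CLS]) token.
    encoded_layers, pooled_output = model(input_ids)
print(len(encoded_layers), encoded_layers[-1].shape, pooled_output.shape)
```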
@@ -156,7 +156,7 @@ Here is a detailed documentation of the classes in the package and how to use th

| Sub-section | Description |
|-|-|
| [Loading Google AI's pre-trained weights](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weights or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the seven PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
| [PyTorch models](#PyTorch-models) | API of the eight PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering` |
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class |
| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
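As a quick illustration of the optimizer row above, a hedged sketch (not from this diff) of `BertAdam`; the `warmup`/`t_total` arguments mirror how the example scripts construct it:

```python
import torch
from pytorch_pretrained_bert.optimization import BertAdam

# Sketch: Adam with BERT's warmup-then-linear-decay schedule on a toy parameter.
w = torch.nn.Parameter(torch.zeros(10))
optimizer = BertAdam([w], lr=2e-5, warmup=0.1, t_total=1000)

loss = (w ** 2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```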
@@ -170,7 +170,7 @@ model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)

where

- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the seven PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification` or `BertForQuestionAnswering`, and
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the eight PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering`, and
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:

- the shortcut name of a Google AI's pre-trained model selected in the list:
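A hedged sketch of that call with concrete values (the shortcut name is illustrative; `num_labels` is the extra argument the class docstring mentions for `BertForSequenceClassification`):

```python
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification

# Sketch: BERT_CLASS.from_pretrained() with a shortcut name. The same call
# accepts a path to a directory/archive produced by the conversion script.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
```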
@@ -353,14 +353,13 @@ The optimizer accepts the following arguments:

BERT-base and BERT-large are respectively 110M and 340M parameter models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32).

To help with fine-tuning these models, we have included five techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.

Here is how to use these techniques in our scripts:

- **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps (see the sketch after this list).
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are split over the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument (see below).
- **Optimize on CPU**: The Adam optimizer stores 2 moving average of the weights of the model. If you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal for large models like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU/RAM to free more room on the GPU(s). As the most computational intensive operation is usually the backward pass, this doesn't have a significant impact on the training time. Activate this option with `--optimize_on_cpu` on the [`run_squad.py`](./examples/run_squad.py) script.
- **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scale` flag (see the previously linked documentation for details on loss scaling). If the loss scaling is too high (`Nan` in the gradients) it will be automatically scaled down until the value is acceptable. The default loss scaling is 128 which behaved nicely in our tests.
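Since the gradient-accumulation bullet only names the flag, here is a self-contained, hedged sketch of the accumulation pattern itself (a toy linear model stands in for BERT):

```python
import torch
import torch.nn.functional as F

# The effective batch is micro_batch_size * accumulation_steps, while only
# micro_batch_size examples sit in memory at any time.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

micro_batches = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(8)]
for step, (x, y) in enumerate(micro_batches):
    loss = F.cross_entropy(model(x), y)
    (loss / accumulation_steps).backward()   # gradients add up in param.grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                      # one update per accumulated batch
        optimizer.zero_grad()
```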
Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above-mentioned blog post](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) for more details):

@@ -371,16 +370,21 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach

### Fine-tuning with BERT: running the examples

We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):

Before running these examples you should download the
- a *sequence-level classifier* on the MRPC classification corpus,
- a *token-level classifier* on the question answering dataset SQuAD, and
- a *sequence-level multiple-choice classifier* on the SWAG classification corpus.
#### MRPC

This example code fine-tunes BERT on the Microsoft Research Paraphrase
Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80 and in 27 seconds (!) on single tesla V100 16GB with apex installed.

Before running this example you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base`
checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section.

This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase
Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80.
and unpack it to some directory `$GLUE_DIR`.

```shell
export GLUE_DIR=/path/to/glue

@@ -401,7 +405,29 @@ python run_classifier.py \

Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) and gave evaluation results between 84% and 88%.
The second example fine-tunes `BERT-Base` on the SQuAD question answering task.

**Fast run with apex and 16 bit precision: fine-tuning on MRPC in 27 seconds!**
First install apex as indicated [here](https://github.com/NVIDIA/apex).
Then run
```shell
export GLUE_DIR=/path/to/glue

python run_classifier.py \
--task_name MRPC \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/MRPC/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
```
#### SQuAD

This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on a single tesla V100 16GB.

The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.

@@ -432,7 +458,9 @@ Training with the previous hyper-parameters gave us the following results:

{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
```
The data for Swag can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)
#### SWAG

The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)

```shell
export SWAG_DIR=/path/to/SWAG

@@ -440,17 +468,18 @@ export SWAG_DIR=/path/to/SWAG

python run_swag.py \
--bert_model bert-base-uncased \
--do_train \
--do_lower_case \
--do_eval \
--data_dir $SWAG_DIR/data
--data_dir $SWAG_DIR/data \
--train_batch_size 16 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--max_seq_length 80 \
--output_dir /tmp/swag_output/
--output_dir /tmp/swag_output/ \
--gradient_accumulation_steps 4
```

Training with the previous hyper-parameters gave us the following results:
Training with the previous hyper-parameters on a single GPU gave us the following results:
```
eval_accuracy = 0.8062081375587323
eval_loss = 0.5966546792367169
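All three example scripts now save the fine-tuned weights as `pytorch_model.bin` in `--output_dir`; a hedged sketch of reloading the SWAG model afterwards (the path follows the command above):

```python
import torch
from pytorch_pretrained_bert.modeling import BertForMultipleChoice

# Sketch: reload the SWAG fine-tuned weights written by run_swag.py.
state_dict = torch.load('/tmp/swag_output/pytorch_model.bin', map_location='cpu')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased',
                                              state_dict=state_dict,
                                              num_choices=4)
model.eval()
```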
docker/Dockerfile (new file)
@@ -0,0 +1,7 @@
FROM pytorch/pytorch:latest

RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext

RUN pip install pytorch-pretrained-bert

WORKDIR /workspace
@@ -168,7 +168,7 @@ def read_examples(input_file):

"""Read a list of `InputExample`s from an input file."""
examples = []
unique_id = 0
with open(input_file, "r") as reader:
with open(input_file, "r", encoding='utf-8') as reader:
while True:
line = reader.readline()
if not line:
@@ -36,13 +36,6 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification

from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)

@@ -98,7 +91,7 @@ class DataProcessor(object):

@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
with open(input_file, "r", encoding='utf-8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:

@@ -329,7 +322,7 @@ def main():

default=None,
type=str,
required=True,
help="The output directory where the model checkpoints will be written.")
help="The output directory where the model predictions and checkpoints will be written.")

## Other parameters
parser.add_argument("--max_seq_length",

@@ -420,7 +413,8 @@ def main():

n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))

if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(

@@ -467,6 +461,11 @@ def main():

model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)

@@ -482,6 +481,12 @@ def main():

if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
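For readers who have not used apex: the hunk above wraps a fused Adam kernel in apex's loss-scaling wrapper. A hedged sketch of that wiring in isolation (assumes apex and a CUDA device; the toy parameter groups stand in for the script's `optimizer_grouped_parameters`):

```python
# Sketch only: requires NVIDIA apex and a CUDA device.
import torch
from apex.optimizers import FusedAdam, FP16_Optimizer

model = torch.nn.Linear(16, 2).cuda().half()
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if n != 'bias'], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'weight_decay': 0.0},
]
optimizer = FusedAdam(optimizer_grouped_parameters,
                      lr=5e-5, bias_correction=False, max_grad_norm=1.0)
# --loss_scale 0 in the scripts means dynamic loss scaling;
# a positive power of two would be passed as static_loss_scale instead.
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

x = torch.randn(4, 16).cuda().half()
loss = model(x).float().sum()
optimizer.backward(loss)   # replaces loss.backward() so the loss can be scaled
optimizer.step()
optimizer.zero_grad()
```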
@@ -546,6 +551,16 @@ def main():

optimizer.zero_grad()
global_step += 1

# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)

# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
model.to(device)

if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
@@ -39,13 +39,6 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering

from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)

@@ -115,7 +108,7 @@ class InputFeatures(object):

def read_squad_examples(input_file, is_training):
"""Read a SQuAD json file into a list of SquadExample."""
with open(input_file, "r") as reader:
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)["data"]

def is_whitespace(c):
@@ -690,7 +683,7 @@ def main():

help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints will be written.")
help="The output directory where the model checkpoints and predictions will be written.")

## Other parameters
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")

@@ -764,7 +757,7 @@ def main():

n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits trainiing: {}".format(
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))

if args.gradient_accumulation_steps < 1:

@@ -813,6 +806,11 @@ def main():

model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)

@@ -834,6 +832,12 @@ def main():

if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,

@@ -911,6 +915,16 @@ def main():

optimizer.zero_grad()
global_step += 1

# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)

# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
model.to(device)

if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False)
@@ -1,5 +1,6 @@

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -99,7 +100,7 @@ class InputFeatures(object):

def read_swag_examples(input_file, is_training):
with open(input_file, 'r') as f:
with open(input_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
lines = list(reader)

@@ -232,34 +233,10 @@ def select_field(features, field):

for feature in features
]

def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the parameters optimized on CPU/RAM back to the model on GPU
"""
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
param_model.data.copy_(param_opti.data)

def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
"""
is_nan = False
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
if param_model.grad is not None:
if test_nan and torch.isnan(param_model.grad).sum() > 0:
is_nan = True
if param_opti.grad is None:
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
param_opti.grad.data.copy_(param_model.grad.data)
else:
param_opti.grad = None
return is_nan

def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x

def main():
parser = argparse.ArgumentParser()
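The `warmup_linear` helper added above is used on the fp16 path, where `FusedAdam` does not apply `BertAdam`'s built-in warmup schedule; a hedged sketch of the schedule it produces:

```python
def warmup_linear(x, warmup=0.002):
    # copied from the script above: ramp up during the first `warmup` fraction
    # of training, then decay linearly towards zero
    if x < warmup:
        return x / warmup
    return 1.0 - x

learning_rate, t_total, warmup_proportion = 2e-5, 1000, 0.1
for global_step in (0, 50, 100, 500, 999):
    lr_this_step = learning_rate * warmup_linear(global_step / t_total, warmup_proportion)
    print(global_step, round(lr_this_step, 8))
```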
@@ -335,17 +312,15 @@ def main():

type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--optimize_on_cpu',
default=False,
action='store_true',
help="Whether to perform optimization and keep the optimizer averages on CPU")
parser.add_argument('--fp16',
default=False,
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=128,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")

args = parser.parse_args()

@@ -353,14 +328,13 @@ def main():

device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
if args.fp16:
logger.info("16-bits training currently not supported in distributed training")
args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))

if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(

@@ -393,38 +367,55 @@ def main():

# Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
num_choices = 4
)
num_choices=4)
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)

# Prepare optimizer
if args.fp16:
param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
for n, param in model.named_parameters()]
elif args.optimize_on_cpu:
param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
for n, param in model.named_parameters()]
else:
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
param_optimizer = list(model.named_parameters())

# hack to remove pooler, which is not used
# thus it produce None grad that break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total)
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total)

global_step = 0
if args.do_train:
@@ -461,30 +452,35 @@ def main():

loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1

if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16 or args.optimize_on_cpu:
if args.fp16 and args.loss_scale != 1.0:
# scale down gradients for fp16 training
for param in model.parameters():
if param.grad is not None:
param.grad.data = param.grad.data / args.loss_scale
is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
if is_nan:
logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
args.loss_scale = args.loss_scale / 2
model.zero_grad()
continue
optimizer.step()
copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
else:
optimizer.step()
model.zero_grad()
# modify learning rate with special warm up BERT uses
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1

# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)

# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForMultipleChoice.from_pretrained(args.bert_model,
state_dict=model_state_dict,
num_choices=4)
model.to(device)

if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
eval_features = convert_examples_to_features(
@@ -1,3 +1,4 @@

__version__ = "0.4.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
@@ -50,7 +50,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor

name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if name[-1] in ["adam_v", "adam_m"]:
if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model

@@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor

l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel':
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias':
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
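For context, the function patched here converts a TensorFlow BERT checkpoint into the PyTorch dump used by `from_pretrained`. A hedged sketch of calling it directly; the paths are placeholders, TensorFlow must be installed to read the checkpoint, and the last keyword name (truncated in the hunk header) is assumed to be `pytorch_dump_path`:

```python
# Sketch: convert a TensorFlow BERT checkpoint to a PyTorch dump.
from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/bert_model.ckpt",
    bert_config_file="/path/to/bert_config.json",
    pytorch_dump_path="/path/to/pytorch_model.bin",
)
```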
@@ -227,7 +227,7 @@ def read_set_from_file(filename: str) -> Set[str]:

Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r') as file_:
with open(filename, 'r', encoding='utf-8') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
@@ -34,9 +34,6 @@ from torch.nn import CrossEntropyLoss

from .file_utils import cached_path

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {

@@ -106,7 +103,7 @@ class BertConfig(object):

initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r") as reader:
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value

@@ -137,7 +134,7 @@ class BertConfig(object):

@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r") as reader:
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))

@@ -448,9 +445,9 @@ class PreTrainedBertModel(nn.Module):

module.bias.data.zero_()

@classmethod
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.

Params:

@@ -464,6 +461,8 @@ class PreTrainedBertModel(nn.Module):

- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
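The new `state_dict` argument documented above is what lets the example scripts reload their own fine-tuned weights. A hedged sketch of the two ways the method can now be called (the local path is a placeholder):

```python
import torch
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering

# Default behaviour: weights come from the (cached) pre-trained archive.
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

# New in this change: pass an in-memory state dict; the weights in the archive
# are then not loaded (the archive is still used for the config).
state_dict = torch.load('/path/to/pytorch_model.bin', map_location='cpu')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased', state_dict=state_dict)
```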
@@ -505,22 +504,23 @@ class PreTrainedBertModel(nn.Module):

logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)

old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma','weight')
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta','bias')
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key]=state_dict.pop(old_key)
state_dict[new_key] = state_dict.pop(old_key)

missing_keys = []
unexpected_keys = []
@@ -25,9 +25,6 @@ import logging

from .file_utils import cached_path

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
@@ -1,6 +1,5 @@

# This installs Pytorch for CUDA 8 only. If you are using a newer version,
# please visit http://pytorch.org/ and install the relevant version.
torch>=0.4.1,<0.5.0
# PyTorch
torch>=0.4.1
# progress bars in model download and training scripts
tqdm
# Accessing files from S3 directly.
setup.py
@@ -1,12 +1,47 @@
"""
Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py

To create the package for pypi.

1. Change the version in __init__.py and setup.py.

2. Commit these changes with the message: "Release: VERSION"

3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
Push the tag to git: git push --tags origin master

4. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).

For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory.
(this will build a wheel for the python version you use to build it - make sure you use python 3.x).

For the sources, run: "python setup.py sdist"
You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp.

5. Check that everything looks correct by uploading the package to the pypi test server:

twine upload dist/* -r pypitest
(pypi suggest using twine as other methods upload files via plaintext.)

Check that you can install it in a virtualenv by running:
pip install -i https://testpypi.python.org/pypi allennlp

6. Upload the final version to actual pypi:
twine upload dist/* -r pypi

7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.

"""
from setuptools import find_packages, setup

setup(
name="pytorch_pretrained_bert",
version="0.3.0",
version="0.4.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
author_email="thomas@huggingface.co",
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
long_description=open("README.md", "r").read(),
long_description=open("README.md", "r", encoding='utf-8').read(),
long_description_content_type="text/markdown",
keywords='BERT NLP deep learning google',
license='Apache',
@@ -32,7 +32,7 @@ class OptimizationTest(unittest.TestCase):

def test_adam(self):
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
criterion = torch.nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = BertAdam(params=[w], lr=2e-1,
weight_decay=0.0,