Mirror of https://github.com/huggingface/transformers.git, commit 66d50ca6ae

README.md
@@ -14,7 +14,7 @@ This implementation is provided with [Google's pre-trained models](https://githu
 | [Doc](#doc) | Detailed documentation |
 | [Examples](#examples) | Detailed examples on how to fine-tune Bert |
 | [Notebooks](#notebooks) | Introduction on the provided Jupyter Notebooks |
-| [TPU](#tup) | Notes on TPU support and pretraining scripts |
+| [TPU](#tpu) | Notes on TPU support and pretraining scripts |
 | [Command-line interface](#Command-line-interface) | Convert a TensorFlow checkpoint in a PyTorch dump |

 ## Installation
@@ -46,13 +46,14 @@ python -m pytest -sv tests/

 This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:

-- Six PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
-  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L535) - raw BERT Transformer model (**fully pre-trained**),
-  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L689) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
-  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L750) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
-  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L618) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
-  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L812) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
-  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L877) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
+- Seven PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
+  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
+  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
+  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
+  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
+  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
+  - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
+  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L946) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).

 - Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
   - `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
@@ -153,7 +154,7 @@ Here is a detailed documentation of the classes in the package and how to use th
 | Sub-section | Description |
 |-|-|
 | [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
-| [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
+| [PyTorch models](#PyTorch-models) | API of the seven PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
 | [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
 | [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |

@@ -167,7 +168,7 @@ model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)

 where

-- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and
+- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the seven PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification` or `BertForQuestionAnswering`, and
 - `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:

   - the shortcut name of a Google AI's pre-trained model selected in the list:
@@ -175,7 +176,9 @@ where
     - `bert-base-uncased`: 12-layer, 768-hidden, 12-heads, 110M parameters
     - `bert-large-uncased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
     - `bert-base-cased`: 12-layer, 768-hidden, 12-heads , 110M parameters
-    - `bert-base-multilingual`: 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+    - `bert-large-cased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
+    - `bert-base-multilingual-uncased`: (Orig, not recommended) 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+    - `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
     - `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters

   - a path or url to a pretrained model archive containing:
@@ -186,6 +189,10 @@ where
 If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
 - `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information)

+`Uncased` means that the text has been lowercased before WordPiece tokenization, e.g., `John Smith` becomes `john smith`. The Uncased model also strips out any accent markers. `Cased` means that the true case and accent markers are preserved. Typically, the Uncased model is better unless you know that case information is important for your task (e.g., Named Entity Recognition or Part-of-Speech tagging). For information about the Multilingual and Chinese model, see the [Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md) or the original TensorFlow repository.
+
+**When using an `uncased model`, make sure to pass `--do_lower_case` to the training scripts. (Or pass `do_lower_case=True` directly to FullTokenizer if you're using your own script.)**
+
 Example:
 ```python
 model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
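To make the options above concrete, a minimal sketch of loading with one of the newly added shortcut names and an explicit cache directory; the names and the cache path below are illustrative, not prescribed by this commit:

```python
from pytorch_pretrained_bert import BertModel, BertForSequenceClassification

# Newly added shortcut name from the list above; weights are fetched from S3 and
# cached locally (the per-rank directory name here is only an example, useful for
# the distributed-training pattern mentioned above).
model = BertForSequenceClassification.from_pretrained(
    'bert-base-multilingual-cased', cache_dir='./pretrained_model_0')

# The raw, fully pre-trained Transformer body can be loaded the same way.
bert = BertModel.from_pretrained('bert-large-cased')
```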
@@ -271,7 +278,13 @@ The sequence-level classifier is a linear layer that takes as input the last hid

 An example on how to use this class is given in the `run_classifier.py` script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.

-#### 6. `BertForQuestionAnswering`
+#### 6. `BertForTokenClassification`
+
+`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.
+
+The token-level classifier is a linear layer that takes as input the last hidden state of the sequence.
+
+#### 7. `BertForQuestionAnswering`

 `BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.

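A short, hedged sketch of how the new `BertForTokenClassification` head described above is used, following the same loss/logits convention as the other heads; the model name, label count and tensor values are toy examples, not part of this commit:

```python
import torch
from pytorch_pretrained_bert import BertForTokenClassification

# Illustrative setup: pre-trained BERT body, freshly initialized token classifier.
# 'bert-base-cased' and num_labels=9 (e.g. an NER tag set) are example choices.
model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=9)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])    # WordPiece token ids
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])       # attention mask
token_type_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]])   # single-sentence input
labels = torch.LongTensor([[1, 2, 0], [3, 0, 0]])           # per-token label ids

loss = model(input_ids, token_type_ids, input_mask, labels)  # scalar CrossEntropy loss
logits = model(input_ids, token_type_ids, input_mask)        # [2, 3, 9] per-token scores
```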
@@ -199,6 +199,7 @@ def main():
                         "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")

     ## Other parameters
+    parser.add_argument("--do_lower_case", default=False, action='store_true', help="Set this flag if you are using an uncased model.")
     parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
     parser.add_argument("--max_seq_length", default=128, type=int,
                         help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
@@ -227,7 +228,7 @@ def main():

     layer_indexes = [int(x) for x in args.layers.split(",")]

-    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
+    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

     examples = read_examples(args.input_file)

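For context, a small sketch of what the new flag controls, assuming the tokenizer classes shipped in this package; the sample sentence and the printed outputs are illustrative:

```python
from pytorch_pretrained_bert import BertTokenizer

# With an uncased checkpoint the text must be lowercased to match the vocabulary,
# which is what --do_lower_case now wires through to the tokenizer.
cased = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
uncased = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

print(cased.tokenize("John Smith"))    # keeps case, e.g. ['John', 'Smith']
print(uncased.tokenize("John Smith"))  # lowercased, e.g. ['john', 'smith']
```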
@@ -376,6 +376,10 @@ def main():
                         default=False,
                         action='store_true',
                         help="Whether to run eval on the dev set.")
+    parser.add_argument("--do_lower_case",
+                        default=False,
+                        action='store_true',
+                        help="Set this flag if you are using an uncased model.")
     parser.add_argument("--train_batch_size",
                         default=32,
                         type=int,
@@ -473,7 +477,7 @@ def main():
     processor = processors[task_name]()
     label_list = processor.get_labels()

-    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
+    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

     train_examples = None
     num_train_steps = None
@@ -542,7 +546,7 @@ def main():
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
-               loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
+               loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.fp16 and args.loss_scale != 1.0:
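This hunk follows the convention adopted in `modeling.py` in the same commit: with labels the classification heads return only the scalar loss, without labels they return the logits. A minimal sketch of both sides of that convention (variable names reuse the ones in the hunk above):

```python
# Training step: labels passed, so the model returns only the scalar loss.
loss = model(input_ids, segment_ids, input_mask, label_ids)

# Evaluation: no labels, so the model returns logits of shape [batch_size, num_labels].
logits = model(input_ids, segment_ids, input_mask)
preds = logits.argmax(dim=-1)
```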
@@ -1,6 +1,7 @@
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
-                       BertForSequenceClassification, BertForQuestionAnswering)
+                       BertForSequenceClassification, BertForTokenClassification,
+                       BertForQuestionAnswering)
 from .optimization import BertAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
@@ -42,7 +42,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
-    'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual.tar.gz",
+    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
+    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
+    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
 }
 CONFIG_NAME = 'bert_config.json'
@@ -476,7 +478,7 @@ class PreTrainedBertModel(nn.Module):
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                   pretrained_model_name))
+                   archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
@@ -557,7 +559,7 @@ class BertModel(PreTrainedBertModel):
            of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
            encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
        - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-           to the last attention block,
+           to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
@@ -567,10 +569,10 @@ class BertModel(PreTrainedBertModel):
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-   token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+   token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-   config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
-       num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+   config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+       num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
|||||||
sentence classification loss.
|
sentence classification loss.
|
||||||
if `masked_lm_labels` or `next_sentence_label` is `None`:
|
if `masked_lm_labels` or `next_sentence_label` is `None`:
|
||||||
Outputs a tuple comprising
|
Outputs a tuple comprising
|
||||||
- the masked language modeling logits, and
|
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
||||||
- the next sentence classification logits.
|
- the next sentence classification logits of shape [batch_size, 2].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
# Already been converted into WordPiece token ids
|
# Already been converted into WordPiece token ids
|
||||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||||
|
|
||||||
config = BertConfig(vocab_size=32000, hidden_size=512,
|
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||||
|
|
||||||
model = BertForPreTraining(config)
|
model = BertForPreTraining(config)
|
||||||
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
@ -712,17 +714,17 @@ class BertForMaskedLM(PreTrainedBertModel):
|
|||||||
if `masked_lm_labels` is `None`:
|
if `masked_lm_labels` is `None`:
|
||||||
Outputs the masked language modeling loss.
|
Outputs the masked language modeling loss.
|
||||||
if `masked_lm_labels` is `None`:
|
if `masked_lm_labels` is `None`:
|
||||||
Outputs the masked language modeling logits.
|
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
# Already been converted into WordPiece token ids
|
# Already been converted into WordPiece token ids
|
||||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||||
|
|
||||||
config = BertConfig(vocab_size=32000, hidden_size=512,
|
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||||
|
|
||||||
model = BertForMaskedLM(config)
|
model = BertForMaskedLM(config)
|
||||||
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
|
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
|
||||||
@@ -774,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
        Outputs the total_loss which is the sum of the masked language modeling loss and the next
        sentence classification loss.
    if `next_sentence_label` is `None`:
-       Outputs the next sentence classification logits.
+       Outputs the next sentence classification logits of shape [batch_size, 2].

    Example usage:
    ```python
@@ -783,8 +785,8 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-   config = BertConfig(vocab_size=32000, hidden_size=512,
-       num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+   config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+       num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForNextSentencePrediction(config)
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
@@ -836,17 +838,17 @@ class BertForSequenceClassification(PreTrainedBertModel):
    if `labels` is not `None`:
        Outputs the CrossEntropy classification loss of the output with the labels.
    if `labels` is `None`:
-       Outputs the classification logits.
+       Outputs the classification logits of shape [batch_size, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-   token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+   token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-   config = BertConfig(vocab_size=32000, hidden_size=512,
-       num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+   config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+       num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_labels = 2

@@ -870,7 +872,73 @@ class BertForSequenceClassification(PreTrainedBertModel):
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss, logits
+            return loss
+        else:
+            return logits
+
+
+class BertForTokenClassification(PreTrainedBertModel):
+    """BERT model for token-level classification.
+    This module is composed of the BERT model with a linear layer on top of
+    the full hidden state of the last layer.
+
+    Params:
+        `config`: a BertConfig class instance with the configuration to build a new model.
+        `num_labels`: the number of classes for the classifier. Default = 2.
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
+        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+            a `sentence B` token (see BERT paper for more details).
+        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+            input sequence length in the current batch. It's the mask that we typically use for attention when
+            a batch has varying length sentences.
+        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
+            with indices selected in [0, ..., num_labels].
+
+    Outputs:
+        if `labels` is not `None`:
+            Outputs the CrossEntropy classification loss of the output with the labels.
+        if `labels` is `None`:
+            Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+    num_labels = 2
+
+    model = BertForTokenClassification(config, num_labels)
+    logits = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+    def __init__(self, config, num_labels=2):
+        super(BertForTokenClassification, self).__init__(config)
+        self.num_labels = num_labels
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, num_labels)
+        self.apply(self.init_bert_weights)
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
+        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            return loss
         else:
             return logits

@@ -914,17 +982,17 @@ class BertForQuestionAnswering(PreTrainedBertModel):
        Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
    if `start_positions` or `end_positions` is `None`:
        Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
-       position tokens.
+       position tokens of shape [batch_size, sequence_length].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-   token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+   token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-   config = BertConfig(vocab_size=32000, hidden_size=512,
-       num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+   config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+       num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForQuestionAnswering(config)
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
@@ -34,9 +34,12 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
-    'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-vocab.txt",
+    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
+    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
 }
+VOCAB_NAME = 'vocab.txt'


 def load_vocab(vocab_file):
@@ -98,7 +101,7 @@ class BertTokenizer(object):
         return tokens

     @classmethod
-    def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
+    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a PreTrainedBertModel from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
@@ -107,16 +110,11 @@ class BertTokenizer(object):
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
        else:
            vocab_file = pretrained_model_name
+       if os.path.isdir(vocab_file):
+           vocab_file = os.path.join(vocab_file, VOCAB_NAME)
        # redirect to the cache, if necessary
        try:
-           resolved_vocab_file = cached_path(vocab_file)
-           if resolved_vocab_file == vocab_file:
-               logger.info("loading vocabulary file {}".format(vocab_file))
-           else:
-               logger.info("loading vocabulary file {} from cache at {}".format(
-                   vocab_file, resolved_vocab_file))
-           # Instantiate tokenizer.
-           tokenizer = cls(resolved_vocab_file, do_lower_case)
+           resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
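A hedged sketch of the two loading paths this refactor enables: a shortcut name cached under an explicit `cache_dir`, or a local directory containing a `vocab.txt`; both paths below are illustrative examples:

```python
from pytorch_pretrained_bert import BertTokenizer

# 1) Shortcut name, with the vocabulary cached in an explicit directory (example path).
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased',
                                          cache_dir='/tmp/bert_cache')

# 2) Local directory: the new os.path.isdir branch appends VOCAB_NAME ('vocab.txt'),
#    so a directory holding an extracted model can be passed directly (example path).
tokenizer = BertTokenizer.from_pretrained('./my_bert_model_dir', do_lower_case=False)
```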
@@ -124,8 +122,15 @@ class BertTokenizer(object):
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                   pretrained_model_name))
-           tokenizer = None
+                   vocab_file))
+           return None
+       if resolved_vocab_file == vocab_file:
+           logger.info("loading vocabulary file {}".format(vocab_file))
+       else:
+           logger.info("loading vocabulary file {} from cache at {}".format(
+               vocab_file, resolved_vocab_file))
+       # Instantiate tokenizer.
+       tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
        return tokenizer


setup.py
@@ -2,7 +2,7 @@ from setuptools import find_packages, setup

 setup(
     name="pytorch_pretrained_bert",
-    version="0.2.0",
+    version="0.3.0",
     author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
     author_email="thomas@huggingface.co",
     description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
@@ -22,7 +22,10 @@ import random

 import torch

-from pytorch_pretrained_bert import BertConfig, BertModel
+from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
+                                     BertForNextSentencePrediction, BertForPreTraining,
+                                     BertForQuestionAnswering, BertForSequenceClassification,
+                                     BertForTokenClassification)


 class BertModelTest(unittest.TestCase):
@@ -35,6 +38,7 @@ class BertModelTest(unittest.TestCase):
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
+                    use_labels=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
@@ -45,7 +49,9 @@ class BertModelTest(unittest.TestCase):
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
+                    type_sequence_label_size=2,
                     initializer_range=0.02,
+                    num_labels=3,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
@@ -53,6 +59,7 @@ class BertModelTest(unittest.TestCase):
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
+           self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
@@ -63,10 +70,12 @@ class BertModelTest(unittest.TestCase):
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
+           self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
+           self.num_labels = num_labels
            self.scope = scope

-       def create_model(self):
+       def prepare_config_and_inputs(self):
            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
|
|||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
vocab_size_or_config_json_file=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -90,10 +105,16 @@ class BertModelTest(unittest.TestCase):
|
|||||||
type_vocab_size=self.type_vocab_size,
|
type_vocab_size=self.type_vocab_size,
|
||||||
initializer_range=self.initializer_range)
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
|
||||||
|
|
||||||
|
def check_loss_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["loss"].size()),
|
||||||
|
[])
|
||||||
|
|
||||||
|
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
model = BertModel(config=config)
|
model = BertModel(config=config)
|
||||||
|
|
||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
|
|
||||||
outputs = {
|
outputs = {
|
||||||
"sequence_output": all_encoder_layers[-1],
|
"sequence_output": all_encoder_layers[-1],
|
||||||
"pooled_output": pooled_output,
|
"pooled_output": pooled_output,
|
||||||
@ -101,13 +122,119 @@ class BertModelTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def check_output(self, result):
|
def check_bert_model_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
[size for layer in result["all_encoder_layers"] for size in layer.size()],
|
||||||
|
[self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["sequence_output"].size()),
|
list(result["sequence_output"].size()),
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
|
||||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
|
||||||
|
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
|
model = BertForMaskedLM(config=config)
|
||||||
|
loss = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||||
|
prediction_scores = model(input_ids, token_type_ids, input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"prediction_scores": prediction_scores,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_masked_lm_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["prediction_scores"].size()),
|
||||||
|
[self.batch_size, self.seq_length, self.vocab_size])
|
||||||
|
|
||||||
|
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
|
model = BertForNextSentencePrediction(config=config)
|
||||||
|
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||||
|
seq_relationship_score = model(input_ids, token_type_ids, input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"seq_relationship_score": seq_relationship_score,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_next_sequence_prediction_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["seq_relationship_score"].size()),
|
||||||
|
[self.batch_size, 2])
|
||||||
|
|
||||||
|
|
||||||
|
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
|
model = BertForPreTraining(config=config)
|
||||||
|
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
|
||||||
|
prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"prediction_scores": prediction_scores,
|
||||||
|
"seq_relationship_score": seq_relationship_score,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_pretraining_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["prediction_scores"].size()),
|
||||||
|
[self.batch_size, self.seq_length, self.vocab_size])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["seq_relationship_score"].size()),
|
||||||
|
[self.batch_size, 2])
|
||||||
|
|
||||||
|
|
||||||
|
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
|
model = BertForQuestionAnswering(config=config)
|
||||||
|
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
|
||||||
|
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"start_logits": start_logits,
|
||||||
|
"end_logits": end_logits,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_question_answering_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["start_logits"].size()),
|
||||||
|
[self.batch_size, self.seq_length])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["end_logits"].size()),
|
||||||
|
[self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
|
||||||
|
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
|
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
|
||||||
|
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||||
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"logits": logits,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_sequence_classification_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["logits"].size()),
|
||||||
|
[self.batch_size, self.num_labels])
|
||||||
|
|
||||||
|
|
||||||
|
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
||||||
|
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
|
||||||
|
loss = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||||
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"logits": logits,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_token_classification_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["logits"].size()),
|
||||||
|
[self.batch_size, self.seq_length, self.num_labels])
|
||||||
|
|
||||||
|
|
||||||
def test_default(self):
|
def test_default(self):
|
||||||
self.run_tester(BertModelTest.BertModelTester(self))
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
|
|
||||||
@ -118,8 +245,33 @@ class BertModelTest(unittest.TestCase):
|
|||||||
self.assertEqual(obj["hidden_size"], 37)
|
self.assertEqual(obj["hidden_size"], 37)
|
||||||
|
|
||||||
def run_tester(self, tester):
|
def run_tester(self, tester):
|
||||||
output_result = tester.create_model()
|
config_and_inputs = tester.prepare_config_and_inputs()
|
||||||
tester.check_output(output_result)
|
output_result = tester.create_bert_model(*config_and_inputs)
|
||||||
|
tester.check_bert_model_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
|
||||||
|
tester.check_bert_for_masked_lm_output(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
|
||||||
|
tester.check_bert_for_next_sequence_prediction_output(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_pretraining(*config_and_inputs)
|
||||||
|
tester.check_bert_for_pretraining_output(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_question_answering(*config_and_inputs)
|
||||||
|
tester.check_bert_for_question_answering_output(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
|
||||||
|
tester.check_bert_for_sequence_classification_output(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_token_classification(*config_and_inputs)
|
||||||
|
tester.check_bert_for_token_classification_output(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||