mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update readme
This commit is contained in:
parent
886cb49792
commit
f920eff8c3
157
README.md
157
README.md
@ -1,36 +1,34 @@
|
||||
# PyTorch implementation of Google AI's BERT model with Google's pre-trained models
|
||||
# PyTorch Pretrained Bert
|
||||
|
||||
This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
|
||||
|
||||
This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below).
|
||||
This implementation is provided with [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script to load any pre-trained TensorFlow checkpoint for BERT is also provided.
|
||||
|
||||
The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated).
|
||||
## Content
|
||||
|
||||
# Documentation
|
||||
|
||||
| Section | Content |
|
||||
| Section | Description |
|
||||
|-|-|
|
||||
| [Installation](#installation) | How to install the package |
|
||||
| [Content](#content) | Overview of the package |
|
||||
| [Usage](#usage) | Quickstart examples |
|
||||
| [Doc](#doc) | Detailed documentation |
|
||||
| [Examples](#examples) | Detailed examples on how to fine-tune Bert |
|
||||
| [Notebooks](#notebooks) | Introduction on the provided Jupyter Notebooks |
|
||||
| [TPU](#tup) | Notes on TPU support and pretraining scripts |
|
||||
| [Command-line interface](#Command-line-interface) | Convert a TensorFlow checkpoint in a PyTorch dump |
|
||||
| [Installation](##installation) | How to install the package |
|
||||
| [Overview](##overview) | Overview of the package |
|
||||
| [Usage](##usage) | Quickstart examples |
|
||||
| [Doc](##doc) | Detailed documentation |
|
||||
| [Examples](##examples) | Detailed examples on how to fine-tune Bert |
|
||||
| [Notebooks](##notebooks) | Introduction on the provided Jupyter Notebooks |
|
||||
| [TPU](##tup) | Notes on TPU support and pretraining scripts |
|
||||
| [Command-line interface](##Command-line-interface) | Convert a TensorFlow checkpoint in a PyTorch dump |
|
||||
|
||||
# Installation
|
||||
## Installation
|
||||
|
||||
This repo was tested on Python 3.5+ and PyTorch 0.4.1
|
||||
|
||||
## From pip
|
||||
### With pip
|
||||
|
||||
PyTorch pretrained bert can be installed by pip as follows:
|
||||
```bash
|
||||
pip install pytorch_pretrained_bert
|
||||
```
|
||||
|
||||
## From source
|
||||
### From source
|
||||
|
||||
Clone the repository and run:
|
||||
```bash
|
||||
@ -44,15 +42,15 @@ You can run the tests with the command:
|
||||
python -m pytest -sv tests/
|
||||
```
|
||||
|
||||
# Content
|
||||
## Overview
|
||||
|
||||
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
|
||||
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](##doc) section of this readme:
|
||||
|
||||
- Six PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights:
|
||||
- `BertModel` - raw BERT Transformer model (**fully pre-trained**),
|
||||
- `BertForMaskedLM` - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
|
||||
- `BertForNextSentencePrediction` - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
|
||||
- `BertForPretraining` - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
|
||||
- `BertForPreTraining` - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
|
||||
- `BertForSequenceClassification` - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
|
||||
- `BertForQuestionAnswering` - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
|
||||
|
||||
@ -74,62 +72,87 @@ The repository further comprises:
|
||||
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
||||
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
|
||||
|
||||
These examples are detailed in the [Examples](#examples) section of this readme.
|
||||
These examples are detailed in the [Examples](##examples) section of this readme.
|
||||
|
||||
- Three notebooks that were used to check that the TensorFlow and PyTorch models behave identically (in the [`notebooks` folder](./notebooks)):
|
||||
- [`Comparing-TF-and-PT-models.ipynb`](./notebooks/Comparing-TF-and-PT-models.ipynb) - Compare the hidden states predicted by `BertModel`,
|
||||
- [`Comparing-TF-and-PT-models-SQuAD.ipynb`](./notebooks/Comparing-TF-and-PT-models-SQuAD.ipynb) - Compare the spans predicted by `BertForQuestionAnswering` instances,
|
||||
- [`Comparing-TF-and-PT-models-MLM-NSP.ipynb`](./notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb) - Compare the predictions of the `BertForPretraining` instances.
|
||||
|
||||
These notebooks are detailed in the [Notebooks](#notebooks) section of this readme.
|
||||
These notebooks are detailed in the [Notebooks](##notebooks) section of this readme.
|
||||
|
||||
- A command-line interface to convert any TensorFlow checkpoint in a PyTorch dump:
|
||||
|
||||
This CLI is detailed in the [Command-line interface](#Command-line-interface) section of this readme.
|
||||
This CLI is detailed in the [Command-line interface](##Command-line-interface) section of this readme.
|
||||
|
||||
# Usage
|
||||
## Usage
|
||||
|
||||
Here is a quick-start example using the `BertForMaskedLM` class with Google AI's pre-trained `Bert base uncased` model:
|
||||
Here is a quick-start example using `BertTokenizer`, `BertModel` and `BertForMaskedLM` class with Google AI's pre-trained `Bert base uncased` model. See the [doc section](##doc) below for all the details on these classes.
|
||||
|
||||
First let's prepare a tokenized input with `BertTokenizer`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from pytorch_pretrained_bert import BertForMaskedLM, BertTokenizer
|
||||
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
|
||||
|
||||
# Load pre-trained model and tokenizer (weights and vocabulary)
|
||||
# Load pre-trained model tokenizer (vocabulary)
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
|
||||
# Prepare tokenized input with a masked token
|
||||
# Tokenized input
|
||||
tokenized_text = "Who was Jim Henson ? Jim Henson was a puppeteer"
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
||||
# Mask a token that we will try to predict back with `BertForMaskedLM`
|
||||
masked_index = 6
|
||||
tokenized_text[masked_index] = '[MASK]'
|
||||
assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']
|
||||
|
||||
# Convert token to vocabulary indices
|
||||
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
|
||||
# Assign sentence A and sentence B indices to 1st (resp 2nd) sentences
|
||||
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
|
||||
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
|
||||
|
||||
# Predict masked tokens with model
|
||||
# Convert inputs to PyTorch tensors
|
||||
tokens_tensor = torch.tensor([indexed_tokens])
|
||||
segments_tensors = torch.tensor([segments_ids])
|
||||
```
|
||||
|
||||
Let's see how to use `BertModel` to get hidden states
|
||||
|
||||
```python
|
||||
# Load pre-trained model (weights)
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
model.eval()
|
||||
|
||||
# Predict hidden states features for each layer
|
||||
encoded_layers, _ = model(tokens_tensor, segments_tensors)
|
||||
# We have a hidden states for each of the 12 layers in model bert-base-uncased
|
||||
assert len(encoded_layers) == 12
|
||||
```
|
||||
|
||||
And how to use `BertForMaskedLM`
|
||||
|
||||
```python
|
||||
# Load pre-trained model (weights)
|
||||
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
model.eval()
|
||||
|
||||
# Predict all tokens
|
||||
predictions = model(tokens_tensor, segments_tensors)
|
||||
|
||||
# Use model to predict
|
||||
# confirm we were able to predict 'henson'
|
||||
predicted_index = torch.argmax(predictions[0, masked_index]).item()
|
||||
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
|
||||
assert predicted_token == 'henson'
|
||||
```
|
||||
|
||||
# Doc
|
||||
## Doc
|
||||
|
||||
Here is a detailed documentation of the classes in the package.
|
||||
Here is a detailed documentation of the classes in the package and how to use them.
|
||||
|
||||
## Loading pre-trained weigths
|
||||
### Loading Google AI's pre-trained weigths and PyTorch dump
|
||||
|
||||
To load Google AI's pre-trained weight, the PyTorch model classes and the tokenizer can be instantiated as
|
||||
To load Google AI's pre-trained weight or a PyTorch saved instance of `BertForPreTraining`, the PyTorch model classes and the tokenizer can be instantiated as
|
||||
|
||||
```python
|
||||
model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH)
|
||||
@ -137,7 +160,7 @@ model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH)
|
||||
|
||||
where
|
||||
|
||||
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPretraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` (to load the pre-trained weights), and
|
||||
- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and
|
||||
|
||||
- `PRE_TRAINED_MODEL_NAME` is either:
|
||||
|
||||
@ -160,9 +183,9 @@ Example:
|
||||
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
||||
```
|
||||
|
||||
## PyTorch models
|
||||
### PyTorch models
|
||||
|
||||
### 1. `BertModel`
|
||||
#### 1. `BertModel`
|
||||
|
||||
`BertModel` is the basic BERT Transformer model with a layer of summed token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large).
|
||||
|
||||
@ -186,14 +209,14 @@ This model *outputs* a tuple composed of:
|
||||
|
||||
An example on how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input.
|
||||
|
||||
### 2. `BertForPreTraining`
|
||||
#### 2. `BertForPreTraining`
|
||||
|
||||
`BertForPreTraining` includes the `BertModel` Transformer followed by the two pre-training heads:
|
||||
|
||||
- the masked language modeling head, and
|
||||
- the next sentence classification head.
|
||||
|
||||
*Inputs* comprises the inputs of the [`BertModel`](###-1.-`BertModel`) class plus two optional labels:
|
||||
*Inputs* comprises the inputs of the [`BertModel`](####-1.-`BertModel`) class plus two optional labels:
|
||||
|
||||
- `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size]
|
||||
- `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
|
||||
@ -205,11 +228,11 @@ An example on how to use this class is given in the `extract_features.py` script
|
||||
- the masked language modeling logits, and
|
||||
- the next sentence classification logits.
|
||||
|
||||
### 3. `BertForMaskedLM`
|
||||
#### 3. `BertForMaskedLM`
|
||||
|
||||
`BertForMaskedLM` includes the `BertModel` Transformer followed by the (possibly) pre-trained masked language modeling head.
|
||||
|
||||
*Inputs* comprises the inputs of the [`BertModel`](###-1.-`BertModel`) class plus optional label:
|
||||
*Inputs* comprises the inputs of the [`BertModel`](####-1.-`BertModel`) class plus optional label:
|
||||
|
||||
- `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size]
|
||||
|
||||
@ -218,11 +241,11 @@ An example on how to use this class is given in the `extract_features.py` script
|
||||
- if `masked_lm_labels` is not `None`: Outputs the masked language modeling loss.
|
||||
- if `masked_lm_labels` is `None`: Outputs the masked language modeling logits.
|
||||
|
||||
### 4. `BertForNextSentencePrediction`
|
||||
#### 4. `BertForNextSentencePrediction`
|
||||
|
||||
`BertForNextSentencePrediction` includes the `BertModel` Transformer followed by the next sentence classification head.
|
||||
|
||||
*Inputs* comprises the inputs of the [`BertModel`](###-1.-`BertModel`) class plus an optional label:
|
||||
*Inputs* comprises the inputs of the [`BertModel`](####-1.-`BertModel`) class plus an optional label:
|
||||
|
||||
- `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
|
||||
|
||||
@ -231,7 +254,7 @@ An example on how to use this class is given in the `extract_features.py` script
|
||||
- if `next_sentence_label` is not `None`: Outputs the next sentence classification loss.
|
||||
- if `next_sentence_label` is `None`: Outputs the next sentence classification logits.
|
||||
|
||||
### 5. `BertForSequenceClassification`
|
||||
#### 5. `BertForSequenceClassification`
|
||||
|
||||
`BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`.
|
||||
|
||||
@ -239,7 +262,7 @@ The sequence-level classifier is a linear layer that takes as input the last hid
|
||||
|
||||
An example on how to use this class is given in the `run_classifier.py` script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
|
||||
|
||||
### 6. `BertForQuestionAnswering`
|
||||
#### 6. `BertForQuestionAnswering`
|
||||
|
||||
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.
|
||||
|
||||
@ -247,9 +270,7 @@ The token-level classifier takes as input the full sequence of the last hidden s
|
||||
|
||||
An example on how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
|
||||
|
||||
## Tokenizers
|
||||
|
||||
### `BertTokenizer`
|
||||
### Tokenizer: `BertTokenizer`
|
||||
|
||||
`BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
|
||||
|
||||
@ -264,13 +285,9 @@ and three methods:
|
||||
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
|
||||
- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
|
||||
|
||||
### `BasicTokenizer` and `WordpieceTokenizer`
|
||||
Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
|
||||
|
||||
Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of these classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
|
||||
|
||||
## Optimizer
|
||||
|
||||
### `BERTAdam`
|
||||
### Optimizer: `BERTAdam`
|
||||
|
||||
`BERTAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following:
|
||||
|
||||
@ -290,11 +307,9 @@ The optimizer accepts the following arguments:
|
||||
- `weight_decay_rate:` Weight decay. Default : 0.01
|
||||
- `max_grad_norm` : Maximum norm for the gradients (-1 means no clipping). Default : 1.0
|
||||
|
||||
# Examples
|
||||
## Examples
|
||||
|
||||
Fine-tuning the models
|
||||
|
||||
## Training large models: introduction, tools and examples
|
||||
### Training large models: introduction, tools and examples
|
||||
|
||||
BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
|
||||
|
||||
@ -314,7 +329,7 @@ python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$TH
|
||||
```
|
||||
Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your machine (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`.
|
||||
|
||||
## Fine-tuning with BERT: running the examples
|
||||
### Fine-tuning with BERT: running the examples
|
||||
|
||||
We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
|
||||
|
||||
@ -381,7 +396,7 @@ Training with the previous hyper-parameters gave us the following results:
|
||||
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
|
||||
```
|
||||
|
||||
# Fine-tuning BERT-large on GPUs
|
||||
## Fine-tuning BERT-large on GPUs
|
||||
|
||||
The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
|
||||
|
||||
@ -443,25 +458,27 @@ The results were similar to the above FP32 results (actually slightly higher):
|
||||
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
|
||||
```
|
||||
|
||||
# Notebooks
|
||||
## Notebooks
|
||||
|
||||
Comparing the PyTorch model and the TensorFlow model predictions
|
||||
|
||||
We also include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
|
||||
|
||||
- The first NoteBook ([Comparing TF and PT models.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models.ipynb)) extracts the hidden states of a full sequence on each layers of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden state of the models.
|
||||
- The first NoteBook ([Comparing-TF-and-PT-models.ipynb](./notebooks/Comparing-TF-and-PT-models.ipynb)) extracts the hidden states of a full sequence on each layers of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden state of the models.
|
||||
|
||||
- The second NoteBook ([Comparing TF and PT models SQuAD predictions.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models%20SQuAD%20predictions.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models.
|
||||
- The second NoteBook ([Comparing-TF-and-PT-models-SQuAD.ipynb](./notebooks/Comparing-TF-and-PT-models-SQuAD.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models.
|
||||
|
||||
Please follow the instructions given in the notebooks to run and modify them. They can also be nice example on how to use the models in a simpler way than the full fine-tuning scripts we provide.
|
||||
- The third NoteBook ([Comparing-TF-and-PT-models-MLM-NSP.ipynb](./notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb)) compares the predictions computed by the TensorFlow and the PyTorch models for masked token using the pre-trained masked language modeling model.
|
||||
|
||||
# Command-line interface
|
||||
Please follow the instructions given in the notebooks to run and modify them.
|
||||
|
||||
## Command-line interface
|
||||
|
||||
A command-line interface is provided to convert a TensorFlow checkpoint in a PyTorch checkpoint
|
||||
|
||||
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
|
||||
|
||||
This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
|
||||
This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
|
||||
|
||||
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
|
||||
|
||||
@ -472,7 +489,7 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncase
|
||||
```shell
|
||||
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
|
||||
|
||||
python convert_tf_checkpoint_to_pytorch.py \
|
||||
pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
|
||||
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
|
||||
--bert_config_file $BERT_BASE_DIR/bert_config.json \
|
||||
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
|
||||
@ -480,7 +497,7 @@ python convert_tf_checkpoint_to_pytorch.py \
|
||||
|
||||
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
|
||||
|
||||
# TPU
|
||||
## TPU
|
||||
|
||||
TPU support and pretraining scripts
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
1
setup.py
1
setup.py
@ -23,7 +23,6 @@ setup(
|
||||
tests_require=['pytest'],
|
||||
classifiers=[
|
||||
'Intended Audience :: Science/Research',
|
||||
'Development Status :: 1 - Alpha',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
|
Loading…
Reference in New Issue
Block a user