test suite independent of framework

This commit is contained in:
thomwolf 2019-09-05 11:18:55 +02:00
parent 9d0a11a68c
commit 518307dfcd
20 changed files with 596 additions and 262 deletions

View File

@ -10,7 +10,7 @@ jobs:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
@ -25,10 +25,9 @@ jobs:
- checkout
- run: sudo pip install tensorflow==2.0.0-rc0
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
- run: codecov
build_py2_torch:
working_directory: ~/pytorch-transformers
@ -40,7 +39,7 @@ jobs:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: codecov
build_py2_tf:
@ -53,7 +52,7 @@ jobs:
- checkout
- run: sudo pip install tensorflow==2.0.0-rc0
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: codecov
deploy_doc:

View File

@ -43,11 +43,11 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
# Modeling
try:
import torch
torch_available = True # pylint: disable=invalid-name
_torch_available = True # pylint: disable=invalid-name
except ImportError:
torch_available = False # pylint: disable=invalid-name
_torch_available = False # pylint: disable=invalid-name
if torch_available:
if _torch_available:
logger.info("PyTorch version {} available.".format(torch.__version__))
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
@ -87,19 +87,26 @@ if torch_available:
# TensorFlow
try:
import tensorflow as tf
tf_available = True # pylint: disable=invalid-name
assert int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
except ImportError:
tf_available = False # pylint: disable=invalid-name
_tf_available = False # pylint: disable=invalid-name
if tf_available:
if _tf_available:
logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_bert_pt_weights_in_tf)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
""" Convert pytorch checkpoints to TensorFlow """
from __future__ import absolute_import
from __future__ import division
@ -21,19 +21,22 @@ from __future__ import print_function
import argparse
import tensorflow as tf
from pytorch_transformers import BertConfig, TFBertForPreTraining, load_pt_weights_in_bert
from pytorch_transformers import BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf
import logging
logging.basicConfig(level=logging.INFO)
def convert_bert_checkpoint_to_tf(pytorch_checkpoint_path, bert_config_file, tf_dump_path):
# Initialise TF model
config = BertConfig.from_json_file(bert_config_file)
print("Building TensorFlow model from configuration: {}".format(str(config)))
model = TFBertForPreTraining(config)
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path):
if model_type == 'bert':
# Initialise TF model
config = BertConfig.from_json_file(config_file)
print("Building TensorFlow model from configuration: {}".format(str(config)))
model = TFBertForPreTraining(config)
# Load weights from tf checkpoint
model = load_pt_weights_in_bert(model, config, pytorch_checkpoint_path)
# Load weights from tf checkpoint
model = load_bert_pt_weights_in_tf(model, config, pytorch_checkpoint_path)
else:
raise ValueError("Unrecognized model type, should be one of ['bert'].")
# Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path))
@ -43,16 +46,21 @@ def convert_bert_checkpoint_to_tf(pytorch_checkpoint_path, bert_config_file, tf_
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--model_type",
default = None,
type = str,
required = True,
help = "Model type selcted in the list of.")
parser.add_argument("--pytorch_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path to the PyTorch checkpoint path.")
parser.add_argument("--bert_config_file",
parser.add_argument("--config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
help = "The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture.")
parser.add_argument("--tf_dump_path",
default = None,
@ -60,6 +68,7 @@ if __name__ == "__main__":
required = True,
help = "Path to the output Tensorflow dump file.")
args = parser.parse_args()
convert_bert_checkpoint_to_tf(args.pytorch_checkpoint_path,
args.bert_config_file,
args.tf_dump_path)
convert_pt_checkpoint_to_tf(args.model_type.lower(),
args.pytorch_checkpoint_path,
args.config_file,
args.tf_dump_path)

View File

@ -51,7 +51,7 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
def load_pt_weights_in_bert(tf_model, config, pytorch_checkpoint_path):
def load_bert_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path):
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
@ -150,6 +150,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@tf.function
def call(self, inputs, training=False):
input_ids, position_ids, token_type_ids = inputs
@ -194,6 +195,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
@tf.function
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
@ -242,6 +244,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@tf.function
def call(self, inputs, training=False):
hidden_states, input_tensor = inputs
@ -261,6 +264,7 @@ class TFBertAttention(tf.keras.layers.Layer):
def prune_heads(self, heads):
raise NotImplementedError
@tf.function
def call(self, inputs, training=False):
input_tensor, attention_mask, head_mask = inputs
@ -279,6 +283,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
else:
self.intermediate_act_fn = config.hidden_act
@tf.function
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
@ -292,6 +297,7 @@ class TFBertOutput(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@tf.function
def call(self, inputs, training=False):
hidden_states, input_tensor = inputs
@ -309,6 +315,7 @@ class TFBertLayer(tf.keras.layers.Layer):
self.intermediate = TFBertIntermediate(config, name='intermediate')
self.bert_output = TFBertOutput(config, name='output')
@tf.function
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
@ -327,6 +334,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
self.output_hidden_states = config.output_hidden_states
self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
@tf.function
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
@ -359,6 +367,7 @@ class TFBertPooler(tf.keras.layers.Layer):
super(TFBertPooler, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
@tf.function
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
@ -377,6 +386,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
@tf.function
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
@ -400,6 +410,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
trainable=True,
name='bias')
@tf.function
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
@ -411,6 +422,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
super(TFBertMLMHead, self).__init__(**kwargs)
self.predictions = TFBertLMPredictionHead(config, name='predictions')
@tf.function
def call(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
@ -421,6 +433,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
super(TFBertNSPHead, self).__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
@tf.function
def call(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
@ -447,6 +460,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
@tf.function
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
@ -459,12 +473,12 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else None
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs.pop('input_ids')
attention_mask = inputs.pop('attention_mask', None)
token_type_ids = inputs.pop('token_type_ids', None)
position_ids = inputs.pop('position_ids', None)
head_mask = inputs.pop('head_mask', None)
assert len(inputs) == 0, "Unexpected inputs detected: {}. Check inputs dict key names.".format(list(inputs.keys()))
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
assert len(inputs) <= 5, "Too many inputs."
if attention_mask is None:
attention_mask = tf.fill(tf.shape(input_ids), 1)
@ -507,23 +521,16 @@ class TFBertMainLayer(tf.keras.layers.Layer):
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
class TFBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = BertConfig
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights = load_pt_weights_in_bert
load_pt_weights = load_bert_pt_weights_in_tf
base_model_prefix = "bert"
def __init__(self, *inputs, **kwargs):
super(TFBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
raise NotImplementedError
BERT_START_DOCSTRING = r""" The BERT model was proposed in
`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
@ -635,6 +642,7 @@ class TFBertModel(TFBertPreTrainedModel):
super(TFBertModel, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
return outputs
@ -687,7 +695,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
# self.apply(self.init_weights) # TODO check added weights initialization
self.tie_weights()
def tie_weights(self):
@ -695,6 +702,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
"""
pass # TODO add weights tying
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
@ -704,14 +712,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
# if masked_lm_labels is not None and next_sentence_label is not None:
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
# next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
# total_loss = masked_lm_loss + next_sentence_loss
# outputs = (total_loss,) + outputs
# TODO add example with losses using model.compile and a dictionary of losses (give names to the output layers)
return outputs # prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@ -753,7 +753,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
# self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
@ -761,6 +760,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
"""
pass # TODO add weights tying
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
@ -768,11 +768,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
prediction_scores = self.cls_mlm(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
# if masked_lm_labels is not None:
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
# outputs = (masked_lm_loss,) + outputs
# TODO example with losses
return outputs # prediction_scores, (hidden_states), (attentions)
@ -815,8 +810,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
# self.apply(self.init_weights)
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
@ -824,9 +818,299 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
seq_relationship_score = self.cls_nsp(pooled_output)
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
# if next_sentence_label is not None:
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
# outputs = (next_sentence_loss,) + outputs
return outputs # seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForSequenceClassification(TFBertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(TFBertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # logits, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
BERT_START_DOCSTRING)
class TFBertForMultipleChoice(TFBertPreTrainedModel):
r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0``
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForMultipleChoice, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(1, name='classifier')
@tf.function
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
assert len(inputs) <= 5, "Too many inputs."
num_choices = tf.shape(input_ids)[1]
seq_length = tf.shape(input_ids)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
outputs = self.bert(flat_inputs, training=training)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForTokenClassification(TFBertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForQuestionAnswering(TFBertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForQuestionAnswering, self).__init__(config)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
outputs = (start_logits, end_logits,) + outputs[2:]
return outputs # start_logits, end_logits, (hidden_states), (attentions)

View File

@ -21,15 +21,18 @@ import shutil
import pytest
import logging
from pytorch_transformers import (AutoConfig, BertConfig,
AutoModel, BertModel,
AutoModelWithLMHead, BertForMaskedLM,
AutoModelForSequenceClassification, BertForSequenceClassification,
AutoModelForQuestionAnswering, BertForQuestionAnswering)
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
try:
from pytorch_transformers import (AutoConfig, BertConfig,
AutoModel, BertModel,
AutoModelWithLMHead, BertForMaskedLM,
AutoModelForSequenceClassification, BertForSequenceClassification,
AutoModelForQuestionAnswering, BertForQuestionAnswering)
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
class AutoModelTest(unittest.TestCase):

View File

@ -20,21 +20,26 @@ import unittest
import shutil
import pytest
from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
try:
from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
class BertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification)
BertForTokenClassification) if is_torch_available() else ()
class BertModelTester(object):

View File

@ -25,12 +25,16 @@ import uuid
import unittest
import logging
import pytest
import torch
try:
import torch
from pytorch_transformers import (PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from pytorch_transformers import (PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
def _config_zero_init(config):

View File

@ -17,9 +17,15 @@ from __future__ import division
from __future__ import print_function
import unittest
import pytest
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
from pytorch_transformers import is_torch_available
try:
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
@ -28,7 +34,7 @@ from .configuration_common_test import ConfigTester
class DistilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
DistilBertForSequenceClassification)
DistilBertForSequenceClassification) if is_torch_available() else None
test_pruning = True
test_torchscript = True
test_resize_embeddings = True

View File

@ -20,9 +20,13 @@ import unittest
import pytest
import shutil
from pytorch_transformers import is_torch_available
from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
try:
from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
@ -30,7 +34,7 @@ from .configuration_common_test import ConfigTester
class GPT2ModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel)
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
class GPT2ModelTester(object):

View File

@ -20,9 +20,13 @@ import unittest
import pytest
import shutil
from pytorch_transformers import is_torch_available
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
try:
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
@ -30,7 +34,7 @@ from .configuration_common_test import ConfigTester
class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
all_model_classes = (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
class OpenAIGPTModelTester(object):

View File

@ -19,10 +19,15 @@ from __future__ import print_function
import unittest
import shutil
import pytest
import torch
from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import is_torch_available
try:
import torch
from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
@ -30,7 +35,7 @@ from .configuration_common_test import ConfigTester
class RobertaModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (RobertaForMaskedLM, RobertaModel)
all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
class RobertaModelTester(object):

View File

@ -24,21 +24,27 @@ import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from pytorch_transformers import BertConfig, is_tf_available
try:
import tensorflow as tf
from pytorch_transformers import (BertConfig)
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_tf_bert import (TFBertModel, TFBertForMaskedLM,
TFBertForNextSentencePrediction,
TFBertForPreTraining,
TFBertForSequenceClassification,
TFBertForMultipleChoice,
TFBertForTokenClassification,
TFBertForQuestionAnswering,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
except ImportError:
pass
pytestmark = pytest.mark.skip("Require TensorFlow")
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFBertModel,)
# BertForMaskedLM, BertForNextSentencePrediction,
# BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
# BertForTokenClassification)
all_model_classes = (TFBertModel, TFBertForMaskedLM, TFBertForNextSentencePrediction,
TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification,
TFBertForTokenClassification) if is_tf_available() else ()
class TFBertModelTester(object):
@ -123,14 +129,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = TFBertModel(config=config)
# model.eval()
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
@ -152,125 +152,115 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForMaskedLM(config=config)
# model.eval()
# loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
# result = {
# "loss": loss,
# "prediction_scores": prediction_scores,
# }
# self.parent.assertListEqual(
# list(result["prediction_scores"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.check_loss_output(result)
model = TFBertForMaskedLM(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
prediction_scores, = model(inputs)
result = {
"prediction_scores": prediction_scores.numpy(),
}
self.parent.assertListEqual(
list(result["prediction_scores"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForNextSentencePrediction(config=config)
# model.eval()
# loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
# result = {
# "loss": loss,
# "seq_relationship_score": seq_relationship_score,
# }
# self.parent.assertListEqual(
# list(result["seq_relationship_score"].size()),
# [self.batch_size, 2])
# self.check_loss_output(result)
model = TFBertForNextSentencePrediction(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
seq_relationship_score, = model(inputs)
result = {
"seq_relationship_score": seq_relationship_score.numpy(),
}
self.parent.assertListEqual(
list(result["seq_relationship_score"].shape),
[self.batch_size, 2])
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForPreTraining(config=config)
# model.eval()
# loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
# result = {
# "loss": loss,
# "prediction_scores": prediction_scores,
# "seq_relationship_score": seq_relationship_score,
# }
# self.parent.assertListEqual(
# list(result["prediction_scores"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(result["seq_relationship_score"].size()),
# [self.batch_size, 2])
# self.check_loss_output(result)
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForQuestionAnswering(config=config)
# model.eval()
# loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
# result = {
# "loss": loss,
# "start_logits": start_logits,
# "end_logits": end_logits,
# }
# self.parent.assertListEqual(
# list(result["start_logits"].size()),
# [self.batch_size, self.seq_length])
# self.parent.assertListEqual(
# list(result["end_logits"].size()),
# [self.batch_size, self.seq_length])
# self.check_loss_output(result)
model = TFBertForPreTraining(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
prediction_scores, seq_relationship_score = model(inputs)
result = {
"prediction_scores": prediction_scores.numpy(),
"seq_relationship_score": seq_relationship_score.numpy(),
}
self.parent.assertListEqual(
list(result["prediction_scores"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(result["seq_relationship_score"].shape),
[self.batch_size, 2])
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# config.num_labels = self.num_labels
# model = BertForSequenceClassification(config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_labels])
# self.check_loss_output(result)
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# config.num_labels = self.num_labels
# model = BertForTokenClassification(config=config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.seq_length, self.num_labels])
# self.check_loss_output(result)
config.num_labels = self.num_labels
model = TFBertForSequenceClassification(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
logits, = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(
list(result["logits"].shape),
[self.batch_size, self.num_labels])
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# config.num_choices = self.num_choices
# model = BertForMultipleChoice(config=config)
# model.eval()
# multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# loss, logits = model(multiple_choice_inputs_ids,
# multiple_choice_token_type_ids,
# multiple_choice_input_mask,
# choice_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_choices])
# self.check_loss_output(result)
config.num_choices = self.num_choices
model = TFBertForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {'input_ids': multiple_choice_inputs_ids,
'attention_mask': multiple_choice_input_mask,
'token_type_ids': multiple_choice_token_type_ids}
logits, = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(
list(result["logits"].shape),
[self.batch_size, self.num_choices])
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = TFBertForTokenClassification(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
logits, = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(
list(result["logits"].shape),
[self.batch_size, self.seq_length, self.num_labels])
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = TFBertForQuestionAnswering(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
start_logits, end_logits = model(inputs)
result = {
"start_logits": start_logits.numpy(),
"end_logits": end_logits.numpy(),
}
self.parent.assertListEqual(
list(result["start_logits"].shape),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].shape),
[self.batch_size, self.seq_length])
def prepare_config_and_inputs_for_common(self):
@ -287,48 +277,39 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def test_config(self):
self.config_tester.run_common_tests()
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_bert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_model(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_next_sequence_prediction(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
@pytest.mark.slow
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

View File

@ -30,7 +30,7 @@ try:
from pytorch_transformers import TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pass
pytestmark = pytest.mark.skip("Require TensorFlow")
def _config_zero_init(config):
@ -50,7 +50,6 @@ class TFCommonTestCases:
test_pruning = True
test_resize_embeddings = True
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_initialization(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -64,7 +63,6 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_attention_outputs(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -105,7 +103,6 @@ class TFCommonTestCases:
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_headmasking(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -153,7 +150,6 @@ class TFCommonTestCases:
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_head_pruning(self):
pass
# if not self.test_pruning:
@ -181,7 +177,6 @@ class TFCommonTestCases:
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_hidden_states_output(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -201,7 +196,6 @@ class TFCommonTestCases:
# [self.model_tester.seq_length, self.model_tester.hidden_size])
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_resize_tokens_embeddings(self):
pass
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -238,7 +232,6 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_tie_model_weights(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -21,17 +21,21 @@ import random
import shutil
import pytest
import torch
from pytorch_transformers import is_torch_available
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
try:
import torch
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
class TransfoXLModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel)
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
test_pruning = False
test_torchscript = False
test_resize_embeddings = False

View File

@ -20,8 +20,14 @@ import unittest
import shutil
import pytest
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import is_torch_available
try:
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
@ -29,9 +35,9 @@ from .configuration_common_test import ConfigTester
class XLMModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XLMModel, XLMWithLMHeadModel,
XLMForQuestionAnswering, XLMForSequenceClassification)
# , XLMForSequenceClassification, XLMForTokenClassification),
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForSequenceClassification) if is_torch_available() else ()
class XLMModelTester(object):

View File

@ -23,10 +23,15 @@ import random
import shutil
import pytest
import torch
from pytorch_transformers import is_torch_available
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
try:
import torch
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
@ -34,7 +39,7 @@ from .configuration_common_test import ConfigTester
class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering)
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
test_pruning = False
class XLNetModelTester(object):

View File

@ -18,11 +18,17 @@ from __future__ import print_function
import unittest
import os
import pytest
import torch
from pytorch_transformers import is_torch_available
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
try:
import torch
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .tokenization_tests_commons import TemporaryDirectory
@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50)
optimizer = AdamW(m.parameters(), lr=10.)
m = torch.nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol):

View File

@ -22,20 +22,19 @@ import pytest
import logging
from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
class AutoTokenizerTest(unittest.TestCase):
def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, BertTokenizer)
self.assertGreater(len(tokenizer), 0)
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, GPT2Tokenizer)

View File

@ -16,15 +16,21 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import pytest
from io import open
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
from pytorch_transformers import is_torch_available
from.tokenization_tests_commons import CommonTestCases
try:
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
except ImportError:
pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
from .tokenization_tests_commons import CommonTestCases
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = TransfoXLTokenizer
tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
def setUp(self):
super(TransfoXLTokenizationTest, self).setUp()

View File

@ -26,16 +26,20 @@ import sys
from collections import Counter, OrderedDict
from io import open
import torch
import numpy as np
from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
try:
import torch
except ImportError:
pass
# if sys.version_info[0] == 2:
# import cPickle as pickle
# else:
# import pickle
logger = logging.getLogger(__name__)