mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
test suite independent of framework
This commit is contained in:
parent
9d0a11a68c
commit
518307dfcd
@ -10,7 +10,7 @@ jobs:
|
||||
- checkout
|
||||
- run: sudo pip install torch
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||
- run: python -m pytest -sv ./examples/
|
||||
@ -25,10 +25,9 @@ jobs:
|
||||
- checkout
|
||||
- run: sudo pip install tensorflow==2.0.0-rc0
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||
- run: python -m pytest -sv ./examples/
|
||||
- run: codecov
|
||||
build_py2_torch:
|
||||
working_directory: ~/pytorch-transformers
|
||||
@ -40,7 +39,7 @@ jobs:
|
||||
- checkout
|
||||
- run: sudo pip install torch
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
|
||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||
- run: codecov
|
||||
build_py2_tf:
|
||||
@ -53,7 +52,7 @@ jobs:
|
||||
- checkout
|
||||
- run: sudo pip install tensorflow==2.0.0-rc0
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest==5.0.1 codecov pytest-cov
|
||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||
- run: codecov
|
||||
deploy_doc:
|
||||
|
@ -43,11 +43,11 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
|
||||
# Modeling
|
||||
try:
|
||||
import torch
|
||||
torch_available = True # pylint: disable=invalid-name
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
except ImportError:
|
||||
torch_available = False # pylint: disable=invalid-name
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
if torch_available:
|
||||
if _torch_available:
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
|
||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||
@ -87,19 +87,26 @@ if torch_available:
|
||||
# TensorFlow
|
||||
try:
|
||||
import tensorflow as tf
|
||||
tf_available = True # pylint: disable=invalid-name
|
||||
assert int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
except ImportError:
|
||||
tf_available = False # pylint: disable=invalid-name
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
if tf_available:
|
||||
if _tf_available:
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
|
||||
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
|
||||
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_bert_pt_weights_in_tf)
|
||||
|
||||
|
||||
# Files and general utilities
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
""" Convert pytorch checkpoints to TensorFlow """
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -21,19 +21,22 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
|
||||
from pytorch_transformers import BertConfig, TFBertForPreTraining, load_pt_weights_in_bert
|
||||
from pytorch_transformers import BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def convert_bert_checkpoint_to_tf(pytorch_checkpoint_path, bert_config_file, tf_dump_path):
|
||||
# Initialise TF model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print("Building TensorFlow model from configuration: {}".format(str(config)))
|
||||
model = TFBertForPreTraining(config)
|
||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path):
|
||||
if model_type == 'bert':
|
||||
# Initialise TF model
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
print("Building TensorFlow model from configuration: {}".format(str(config)))
|
||||
model = TFBertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
model = load_pt_weights_in_bert(model, config, pytorch_checkpoint_path)
|
||||
# Load weights from tf checkpoint
|
||||
model = load_bert_pt_weights_in_tf(model, config, pytorch_checkpoint_path)
|
||||
else:
|
||||
raise ValueError("Unrecognized model type, should be one of ['bert'].")
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save TensorFlow model to {}".format(tf_dump_path))
|
||||
@ -43,16 +46,21 @@ def convert_bert_checkpoint_to_tf(pytorch_checkpoint_path, bert_config_file, tf_
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--model_type",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Model type selcted in the list of.")
|
||||
parser.add_argument("--pytorch_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the PyTorch checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
parser.add_argument("--config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
help = "The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--tf_dump_path",
|
||||
default = None,
|
||||
@ -60,6 +68,7 @@ if __name__ == "__main__":
|
||||
required = True,
|
||||
help = "Path to the output Tensorflow dump file.")
|
||||
args = parser.parse_args()
|
||||
convert_bert_checkpoint_to_tf(args.pytorch_checkpoint_path,
|
||||
args.bert_config_file,
|
||||
args.tf_dump_path)
|
||||
convert_pt_checkpoint_to_tf(args.model_type.lower(),
|
||||
args.pytorch_checkpoint_path,
|
||||
args.config_file,
|
||||
args.tf_dump_path)
|
@ -51,7 +51,7 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_pt_weights_in_bert(tf_model, config, pytorch_checkpoint_path):
|
||||
def load_bert_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path):
|
||||
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
|
||||
We use HDF5 to easily do transfer learning
|
||||
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
||||
@ -150,6 +150,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
input_ids, position_ids, token_type_ids = inputs
|
||||
|
||||
@ -194,6 +195,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask = inputs
|
||||
|
||||
@ -242,6 +244,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, input_tensor = inputs
|
||||
|
||||
@ -261,6 +264,7 @@ class TFBertAttention(tf.keras.layers.Layer):
|
||||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
input_tensor, attention_mask, head_mask = inputs
|
||||
|
||||
@ -279,6 +283,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
@tf.function
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
@ -292,6 +297,7 @@ class TFBertOutput(tf.keras.layers.Layer):
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, input_tensor = inputs
|
||||
|
||||
@ -309,6 +315,7 @@ class TFBertLayer(tf.keras.layers.Layer):
|
||||
self.intermediate = TFBertIntermediate(config, name='intermediate')
|
||||
self.bert_output = TFBertOutput(config, name='output')
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask = inputs
|
||||
|
||||
@ -327,6 +334,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask = inputs
|
||||
|
||||
@ -359,6 +367,7 @@ class TFBertPooler(tf.keras.layers.Layer):
|
||||
super(TFBertPooler, self).__init__(**kwargs)
|
||||
self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
|
||||
|
||||
@tf.function
|
||||
def call(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
@ -377,6 +386,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||
|
||||
@tf.function
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
@ -400,6 +410,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
trainable=True,
|
||||
name='bias')
|
||||
|
||||
@tf.function
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states) + self.bias
|
||||
@ -411,6 +422,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
|
||||
super(TFBertMLMHead, self).__init__(**kwargs)
|
||||
self.predictions = TFBertLMPredictionHead(config, name='predictions')
|
||||
|
||||
@tf.function
|
||||
def call(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
return prediction_scores
|
||||
@ -421,6 +433,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
||||
super(TFBertNSPHead, self).__init__(**kwargs)
|
||||
self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
|
||||
|
||||
@tf.function
|
||||
def call(self, pooled_output):
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return seq_relationship_score
|
||||
@ -447,6 +460,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
if not isinstance(inputs, (dict, tuple, list)):
|
||||
input_ids = inputs
|
||||
@ -459,12 +473,12 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = inputs[4] if len(inputs) > 4 else None
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs.pop('input_ids')
|
||||
attention_mask = inputs.pop('attention_mask', None)
|
||||
token_type_ids = inputs.pop('token_type_ids', None)
|
||||
position_ids = inputs.pop('position_ids', None)
|
||||
head_mask = inputs.pop('head_mask', None)
|
||||
assert len(inputs) == 0, "Unexpected inputs detected: {}. Check inputs dict key names.".format(list(inputs.keys()))
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', None)
|
||||
token_type_ids = inputs.get('token_type_ids', None)
|
||||
position_ids = inputs.get('position_ids', None)
|
||||
head_mask = inputs.get('head_mask', None)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(tf.shape(input_ids), 1)
|
||||
@ -507,23 +521,16 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
class TFBertPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = BertConfig
|
||||
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_pt_weights_in_bert
|
||||
load_pt_weights = load_bert_pt_weights_in_tf
|
||||
base_model_prefix = "bert"
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super(TFBertPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
|
||||
def init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
BERT_START_DOCSTRING = r""" The BERT model was proposed in
|
||||
`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
|
||||
@ -635,6 +642,7 @@ class TFBertModel(TFBertPreTrainedModel):
|
||||
super(TFBertModel, self).__init__(config)
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
return outputs
|
||||
@ -687,7 +695,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
||||
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
|
||||
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
|
||||
|
||||
# self.apply(self.init_weights) # TODO check added weights initialization
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@ -695,6 +702,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
||||
"""
|
||||
pass # TODO add weights tying
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
|
||||
@ -704,14 +712,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
||||
|
||||
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
# if masked_lm_labels is not None and next_sentence_label is not None:
|
||||
# loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
# next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
# total_loss = masked_lm_loss + next_sentence_loss
|
||||
# outputs = (total_loss,) + outputs
|
||||
# TODO add example with losses using model.compile and a dictionary of losses (give names to the output layers)
|
||||
|
||||
return outputs # prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@ -753,7 +753,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
|
||||
|
||||
# self.apply(self.init_weights)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@ -761,6 +760,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
||||
"""
|
||||
pass # TODO add weights tying
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
|
||||
@ -768,11 +768,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
||||
prediction_scores = self.cls_mlm(sequence_output)
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
# if masked_lm_labels is not None:
|
||||
# loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
# outputs = (masked_lm_loss,) + outputs
|
||||
# TODO example with losses
|
||||
|
||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
@ -815,8 +810,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
|
||||
|
||||
# self.apply(self.init_weights)
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
|
||||
@ -824,9 +818,299 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
seq_relationship_score = self.cls_nsp(pooled_output)
|
||||
|
||||
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
||||
# if next_sentence_label is not None:
|
||||
# loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
# next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
# outputs = (next_sentence_loss,) + outputs
|
||||
|
||||
return outputs # seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class TFBertForSequenceClassification(TFBertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(TFBertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
return outputs # logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
BERT_START_DOCSTRING)
|
||||
class TFBertForMultipleChoice(TFBertPreTrainedModel):
|
||||
r"""
|
||||
Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
The second dimension of the input (`num_choices`) indicates the number of choices to score.
|
||||
To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
|
||||
|
||||
(a) For sequence pairs:
|
||||
|
||||
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
|
||||
|
||||
(b) For single sequences:
|
||||
|
||||
``tokens: [CLS] the dog is hairy . [SEP]``
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0``
|
||||
|
||||
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
The second dimension of the input (`num_choices`) indicates the number of choices to score.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
The second dimension of the input (`num_choices`) indicates the number of choices to score.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the multiple choice classification loss.
|
||||
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above)
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss.
|
||||
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above).
|
||||
Classification scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
|
||||
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
||||
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
||||
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, classification_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(TFBertForMultipleChoice, self).__init__(config)
|
||||
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(1, name='classifier')
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
if not isinstance(inputs, (dict, tuple, list)):
|
||||
input_ids = inputs
|
||||
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
|
||||
elif isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else None
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else None
|
||||
position_ids = inputs[3] if len(inputs) > 3 else None
|
||||
head_mask = inputs[4] if len(inputs) > 4 else None
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', None)
|
||||
token_type_ids = inputs.get('token_type_ids', None)
|
||||
position_ids = inputs.get('position_ids', None)
|
||||
head_mask = inputs.get('head_mask', None)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
|
||||
num_choices = tf.shape(input_ids)[1]
|
||||
seq_length = tf.shape(input_ids)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
|
||||
outputs = self.bert(flat_inputs, training=training)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
return outputs # reshaped_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class TFBertForTokenClassification(TFBertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for computing the token classification loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss.
|
||||
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
||||
Classification scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(TFBertForTokenClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
return outputs # scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class TFBertForQuestionAnswering(TFBertPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
start_positions = torch.tensor([1])
|
||||
end_positions = torch.tensor([3])
|
||||
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
||||
loss, start_scores, end_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(TFBertForQuestionAnswering, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
|
||||
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.bert(inputs, training=training)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
outputs = (start_logits, end_logits,) + outputs[2:]
|
||||
|
||||
return outputs # start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
@ -21,15 +21,18 @@ import shutil
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
from pytorch_transformers import (AutoConfig, BertConfig,
|
||||
AutoModel, BertModel,
|
||||
AutoModelWithLMHead, BertForMaskedLM,
|
||||
AutoModelForSequenceClassification, BertForSequenceClassification,
|
||||
AutoModelForQuestionAnswering, BertForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
try:
|
||||
from pytorch_transformers import (AutoConfig, BertConfig,
|
||||
AutoModel, BertModel,
|
||||
AutoModelWithLMHead, BertForMaskedLM,
|
||||
AutoModelForSequenceClassification, BertForSequenceClassification,
|
||||
AutoModelForQuestionAnswering, BertForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
|
||||
class AutoModelTest(unittest.TestCase):
|
||||
|
@ -20,21 +20,26 @@ import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
BertForNextSentencePrediction, BertForPreTraining,
|
||||
BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification, BertForMultipleChoice)
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
|
||||
try:
|
||||
from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
BertForNextSentencePrediction, BertForPreTraining,
|
||||
BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification, BertForMultipleChoice)
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
|
||||
class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification)
|
||||
BertForTokenClassification) if is_torch_available() else ()
|
||||
|
||||
class BertModelTester(object):
|
||||
|
||||
|
@ -25,12 +25,16 @@ import uuid
|
||||
|
||||
import unittest
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
try:
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (PretrainedConfig, PreTrainedModel,
|
||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from pytorch_transformers import (PretrainedConfig, PreTrainedModel,
|
||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
|
@ -17,9 +17,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
|
||||
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
try:
|
||||
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
|
||||
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
@ -28,7 +34,7 @@ from .configuration_common_test import ConfigTester
|
||||
class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
|
||||
DistilBertForSequenceClassification)
|
||||
DistilBertForSequenceClassification) if is_torch_available() else None
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
|
@ -20,9 +20,13 @@ import unittest
|
||||
import pytest
|
||||
import shutil
|
||||
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2DoubleHeadsModel)
|
||||
try:
|
||||
from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2DoubleHeadsModel)
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
@ -30,7 +34,7 @@ from .configuration_common_test import ConfigTester
|
||||
|
||||
class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel)
|
||||
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
|
||||
class GPT2ModelTester(object):
|
||||
|
||||
|
@ -20,9 +20,13 @@ import unittest
|
||||
import pytest
|
||||
import shutil
|
||||
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||
try:
|
||||
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
@ -30,7 +34,7 @@ from .configuration_common_test import ConfigTester
|
||||
|
||||
class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||
all_model_classes = (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
|
||||
|
||||
class OpenAIGPTModelTester(object):
|
||||
|
||||
|
@ -19,10 +19,15 @@ from __future__ import print_function
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
|
||||
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
|
||||
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
@ -30,7 +35,7 @@ from .configuration_common_test import ConfigTester
|
||||
|
||||
class RobertaModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (RobertaForMaskedLM, RobertaModel)
|
||||
all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
|
||||
|
||||
class RobertaModelTester(object):
|
||||
|
||||
|
@ -24,21 +24,27 @@ import sys
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
|
||||
from pytorch_transformers import BertConfig, is_tf_available
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
from pytorch_transformers import (BertConfig)
|
||||
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_tf_bert import (TFBertModel, TFBertForMaskedLM,
|
||||
TFBertForNextSentencePrediction,
|
||||
TFBertForPreTraining,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertForMultipleChoice,
|
||||
TFBertForTokenClassification,
|
||||
TFBertForQuestionAnswering,
|
||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
except ImportError:
|
||||
pass
|
||||
pytestmark = pytest.mark.skip("Require TensorFlow")
|
||||
|
||||
|
||||
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
all_model_classes = (TFBertModel,)
|
||||
# BertForMaskedLM, BertForNextSentencePrediction,
|
||||
# BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
|
||||
# BertForTokenClassification)
|
||||
all_model_classes = (TFBertModel, TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
||||
TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification,
|
||||
TFBertForTokenClassification) if is_tf_available() else ()
|
||||
|
||||
class TFBertModelTester(object):
|
||||
|
||||
@ -123,14 +129,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
|
||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFBertModel(config=config)
|
||||
# model.eval()
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
@ -152,125 +152,115 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
|
||||
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# model = BertForMaskedLM(config=config)
|
||||
# model.eval()
|
||||
# loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "prediction_scores": prediction_scores,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["prediction_scores"].size()),
|
||||
# [self.batch_size, self.seq_length, self.vocab_size])
|
||||
# self.check_loss_output(result)
|
||||
model = TFBertForMaskedLM(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
prediction_scores, = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
||||
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# model = BertForNextSentencePrediction(config=config)
|
||||
# model.eval()
|
||||
# loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "seq_relationship_score": seq_relationship_score,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["seq_relationship_score"].size()),
|
||||
# [self.batch_size, 2])
|
||||
# self.check_loss_output(result)
|
||||
model = TFBertForNextSentencePrediction(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
seq_relationship_score, = model(inputs)
|
||||
result = {
|
||||
"seq_relationship_score": seq_relationship_score.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["seq_relationship_score"].shape),
|
||||
[self.batch_size, 2])
|
||||
|
||||
|
||||
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# model = BertForPreTraining(config=config)
|
||||
# model.eval()
|
||||
# loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "prediction_scores": prediction_scores,
|
||||
# "seq_relationship_score": seq_relationship_score,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["prediction_scores"].size()),
|
||||
# [self.batch_size, self.seq_length, self.vocab_size])
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["seq_relationship_score"].size()),
|
||||
# [self.batch_size, 2])
|
||||
# self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# model = BertForQuestionAnswering(config=config)
|
||||
# model.eval()
|
||||
# loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "start_logits": start_logits,
|
||||
# "end_logits": end_logits,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["start_logits"].size()),
|
||||
# [self.batch_size, self.seq_length])
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["end_logits"].size()),
|
||||
# [self.batch_size, self.seq_length])
|
||||
# self.check_loss_output(result)
|
||||
model = TFBertForPreTraining(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
prediction_scores, seq_relationship_score = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
"seq_relationship_score": seq_relationship_score.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(result["seq_relationship_score"].shape),
|
||||
[self.batch_size, 2])
|
||||
|
||||
|
||||
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# config.num_labels = self.num_labels
|
||||
# model = BertForSequenceClassification(config)
|
||||
# model.eval()
|
||||
# loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "logits": logits,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["logits"].size()),
|
||||
# [self.batch_size, self.num_labels])
|
||||
# self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# config.num_labels = self.num_labels
|
||||
# model = BertForTokenClassification(config=config)
|
||||
# model.eval()
|
||||
# loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "logits": logits,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["logits"].size()),
|
||||
# [self.batch_size, self.seq_length, self.num_labels])
|
||||
# self.check_loss_output(result)
|
||||
config.num_labels = self.num_labels
|
||||
model = TFBertForSequenceClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.num_labels])
|
||||
|
||||
|
||||
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
pass
|
||||
# config.num_choices = self.num_choices
|
||||
# model = BertForMultipleChoice(config=config)
|
||||
# model.eval()
|
||||
# multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
# multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
# multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
# loss, logits = model(multiple_choice_inputs_ids,
|
||||
# multiple_choice_token_type_ids,
|
||||
# multiple_choice_input_mask,
|
||||
# choice_labels)
|
||||
# result = {
|
||||
# "loss": loss,
|
||||
# "logits": logits,
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["logits"].size()),
|
||||
# [self.batch_size, self.num_choices])
|
||||
# self.check_loss_output(result)
|
||||
config.num_choices = self.num_choices
|
||||
model = TFBertForMultipleChoice(config=config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||
inputs = {'input_ids': multiple_choice_inputs_ids,
|
||||
'attention_mask': multiple_choice_input_mask,
|
||||
'token_type_ids': multiple_choice_token_type_ids}
|
||||
logits, = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.num_choices])
|
||||
|
||||
|
||||
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFBertForTokenClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
|
||||
|
||||
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFBertForQuestionAnswering(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
start_logits, end_logits = model(inputs)
|
||||
result = {
|
||||
"start_logits": start_logits.numpy(),
|
||||
"end_logits": end_logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@ -287,48 +277,39 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_bert_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_model(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_next_sequence_prediction(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_pretraining(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
|
@ -30,7 +30,7 @@ try:
|
||||
from pytorch_transformers import TFPreTrainedModel
|
||||
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
except ImportError:
|
||||
pass
|
||||
pytestmark = pytest.mark.skip("Require TensorFlow")
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
@ -50,7 +50,6 @@ class TFCommonTestCases:
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -64,7 +63,6 @@ class TFCommonTestCases:
|
||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -105,7 +103,6 @@ class TFCommonTestCases:
|
||||
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
|
||||
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_headmasking(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -153,7 +150,6 @@ class TFCommonTestCases:
|
||||
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_head_pruning(self):
|
||||
pass
|
||||
# if not self.test_pruning:
|
||||
@ -181,7 +177,6 @@ class TFCommonTestCases:
|
||||
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -201,7 +196,6 @@ class TFCommonTestCases:
|
||||
# [self.model_tester.seq_length, self.model_tester.hidden_size])
|
||||
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -238,7 +232,6 @@ class TFCommonTestCases:
|
||||
# self.assertTrue(models_equal)
|
||||
|
||||
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
@ -21,17 +21,21 @@ import random
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
|
||||
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
try:
|
||||
import torch
|
||||
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
|
||||
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
|
||||
class TransfoXLModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel)
|
||||
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
|
@ -20,8 +20,14 @@ import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
|
||||
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
try:
|
||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
||||
XLMForSequenceClassification)
|
||||
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
@ -29,9 +35,9 @@ from .configuration_common_test import ConfigTester
|
||||
|
||||
class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (XLMModel, XLMWithLMHeadModel,
|
||||
XLMForQuestionAnswering, XLMForSequenceClassification)
|
||||
# , XLMForSequenceClassification, XLMForTokenClassification),
|
||||
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
||||
XLMForSequenceClassification) if is_torch_available() else ()
|
||||
|
||||
|
||||
class XLMModelTester(object):
|
||||
|
||||
|
@ -23,10 +23,15 @@ import random
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
try:
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
@ -34,7 +39,7 @@ from .configuration_common_test import ConfigTester
|
||||
class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes=(XLNetModel, XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
class XLNetModelTester(object):
|
||||
|
@ -18,11 +18,17 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
try:
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .tokenization_tests_commons import TemporaryDirectory
|
||||
|
||||
@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optimizer = AdamW(m.parameters(), lr=10.)
|
||||
m = torch.nn.Linear(50, 50) if is_torch_available() else None
|
||||
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
|
||||
num_steps = 10
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol):
|
||||
|
@ -22,20 +22,19 @@ import pytest
|
||||
import logging
|
||||
|
||||
from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
|
||||
class AutoTokenizerTest(unittest.TestCase):
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
self.assertIsInstance(tokenizer, BertTokenizer)
|
||||
self.assertGreater(len(tokenizer), 0)
|
||||
|
||||
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
self.assertIsInstance(tokenizer, GPT2Tokenizer)
|
||||
|
@ -16,15 +16,21 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import pytest
|
||||
from io import open
|
||||
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from.tokenization_tests_commons import CommonTestCases
|
||||
try:
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
|
||||
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
|
||||
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
|
||||
tokenizer_class = TransfoXLTokenizer
|
||||
tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super(TransfoXLTokenizationTest, self).setUp()
|
||||
|
@ -26,16 +26,20 @@ import sys
|
||||
from collections import Counter, OrderedDict
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# if sys.version_info[0] == 2:
|
||||
# import cPickle as pickle
|
||||
# else:
|
||||
# import pickle
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
Loading…
Reference in New Issue
Block a user