fix test when tf is not here

This commit is contained in:
thomwolf 2019-09-05 02:53:52 +02:00
parent 59fe641b8b
commit ad0ab9afe9
4 changed files with 101 additions and 46 deletions

View File

@ -11,6 +11,7 @@ jobs:
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install tensorboardX scikit-learn
- run: sudo pip install tensorflow==2.0.0-rc0
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
- run: codecov
@ -24,6 +25,7 @@ jobs:
- checkout
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install tensorflow==2.0.0-rc0
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: codecov
deploy_doc:

View File

@ -1,4 +1,5 @@
__version__ = "1.2.0"
# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
# see: https://github.com/abseil/abseil-py/issues/99
@ -11,6 +12,10 @@ try:
except:
pass
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Tokenizer
from .tokenization_utils import (PreTrainedTokenizer)
from .tokenization_auto import AutoTokenizer
@ -36,38 +41,63 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
# Modeling
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
AutoModelWithLMHead)
try:
import torch
torch_available = True # pylint: disable=invalid-name
except ImportError:
torch_available = False # pylint: disable=invalid-name
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering,
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering,
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
if torch_available:
logger.info("PyTorch version {} available.".format(torch.__version__))
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
AutoModelWithLMHead)
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering,
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering,
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
# TensorFlow
try:
import tensorflow as tf
tf_available = True # pylint: disable=invalid-name
except ImportError:
tf_available = False # pylint: disable=invalid-name
if tf_available:
logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,

View File

@ -19,15 +19,19 @@ from __future__ import print_function
import unittest
import shutil
import pytest
import tensorflow as tf
from pytorch_transformers import (BertConfig)
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
try:
import tensorflow as tf
from pytorch_transformers import (BertConfig)
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pass
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@ -283,39 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def test_config(self):
self.config_tester.run_common_tests()
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_bert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_model(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_next_sequence_prediction(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
@pytest.mark.slow
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@ -325,3 +338,4 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
if __name__ == "__main__":
unittest.main()

View File

@ -12,24 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import copy
import os
import shutil
import json
import logging
import random
import shutil
import unittest
import uuid
import unittest
import logging
import pytest
import sys
import tensorflow as tf
from pytorch_transformers import TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
try:
import tensorflow as tf
from pytorch_transformers import TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
pass
def _config_zero_init(config):
@ -49,6 +50,7 @@ class TFCommonTestCases:
test_pruning = True
test_resize_embeddings = True
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_initialization(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -62,6 +64,7 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_attention_outputs(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -102,6 +105,7 @@ class TFCommonTestCases:
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_headmasking(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -149,6 +153,7 @@ class TFCommonTestCases:
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_head_pruning(self):
pass
# if not self.test_pruning:
@ -176,6 +181,7 @@ class TFCommonTestCases:
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_hidden_states_output(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -195,6 +201,7 @@ class TFCommonTestCases:
# [self.model_tester.seq_length, self.model_tester.hidden_size])
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_resize_tokens_embeddings(self):
pass
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -231,6 +238,7 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_tie_model_weights(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -282,6 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
class TFModelUtilsTest(unittest.TestCase):
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
def test_model_from_pretrained(self):
pass
# logging.basicConfig(level=logging.INFO)