fix test when tf is not here

thomwolf 2019-09-05 02:53:52 +02:00
parent 59fe641b8b
commit ad0ab9afe9
4 changed files with 101 additions and 46 deletions
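The pattern applied across these files: the optional TensorFlow import is wrapped in try/except so the package and its test modules still import when TF is missing, and TF-only tests are skipped with a pytest marker instead of erroring. A minimal, self-contained sketch of that pattern follows; it is not taken from the diff, the class and test names are illustrative, and it uses a boolean flag for the skip condition as a simplification of the string condition used in the commit. Run it under pytest, as the CI does.

import unittest

import pytest

try:
    import tensorflow as tf  # optional dependency; may be absent
    tf_available = True
except ImportError:
    tf_available = False


class OptionalTFTest(unittest.TestCase):

    @pytest.mark.skipif(not tf_available, reason="requires TensorFlow")
    def test_constant(self):
        # Only runs when the import above succeeded; otherwise pytest reports a skip.
        self.assertEqual(int(tf.constant(1).numpy()), 1)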

.circleci/config.yml

@@ -11,6 +11,7 @@ jobs:
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install tensorboardX scikit-learn
- run: sudo pip install tensorflow==2.0.0-rc0
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
- run: codecov
@@ -24,6 +25,7 @@ jobs:
- checkout
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install tensorflow==2.0.0-rc0
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: codecov
deploy_doc:

pytorch_transformers/__init__.py

@@ -1,4 +1,5 @@
__version__ = "1.2.0"
# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
# see: https://github.com/abseil/abseil-py/issues/99
@@ -11,6 +12,10 @@ try:
except:
    pass
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Tokenizer
from .tokenization_utils import (PreTrainedTokenizer)
from .tokenization_auto import AutoTokenizer
@@ -36,6 +41,15 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
# Modeling
try:
    import torch
    torch_available = True # pylint: disable=invalid-name
except ImportError:
    torch_available = False # pylint: disable=invalid-name

if torch_available:
    logger.info("PyTorch version {} available.".format(torch.__version__))
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
AutoModelWithLMHead)
@@ -69,6 +83,22 @@ from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
# TensorFlow
try:
    import tensorflow as tf
    tf_available = True # pylint: disable=invalid-name
except ImportError:
    tf_available = False # pylint: disable=invalid-name

if tf_available:
    logger.info("TensorFlow version {} available.".format(tf.__version__))
    from .modeling_tf_utils import TFPreTrainedModel
    from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
                                   TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
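For context, a sketch (not part of the diff) of how downstream code can key off the tf_available flag that this __init__.py sets at import time; the getattr lookup is just one illustrative way to read it.

import pytorch_transformers

# TF-only classes such as TFBertModel are re-exported only when the TensorFlow
# import above succeeded, so check the flag before touching them.
if getattr(pytorch_transformers, "tf_available", False):
    from pytorch_transformers import TFBertModel
else:
    from pytorch_transformers import BertModel  # PyTorch implementation (requires torch)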

pytorch_transformers/tests/modeling_tf_bert_test.py

@@ -19,14 +19,18 @@ from __future__ import print_function
import unittest
import shutil
import pytest
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
try:
    import tensorflow as tf
    from pytorch_transformers import (BertConfig)
    from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
    from .configuration_common_test import ConfigTester
except ImportError:
    pass
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@@ -283,39 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
    def test_config(self):
        self.config_tester.run_common_tests()

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_bert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_model(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)

    @pytest.mark.slow
    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -325,3 +338,4 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
if __name__ == "__main__":
    unittest.main()
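Usage note (not part of the commit): CI runs the whole folder with python -m pytest -sv ./pytorch_transformers/tests/; the same module can be exercised on its own, with or without TensorFlow installed, for example:

import pytest

# Illustrative: run just this test module the way CircleCI runs the full test folder.
# Without TensorFlow the tests marked with skipif are reported as skipped, not as errors.
exit_code = pytest.main(["-sv", "pytorch_transformers/tests/modeling_tf_bert_test.py"])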

pytorch_transformers/tests/modeling_tf_common_test.py

@@ -12,24 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import copy
import os
import shutil
import json
import logging
import random
import shutil
import unittest
import uuid
import unittest
import logging
import pytest
import sys
try:
    import tensorflow as tf
    from pytorch_transformers import TFPreTrainedModel
    # from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except ImportError:
    pass
def _config_zero_init(config):
@@ -49,6 +50,7 @@ class TFCommonTestCases:
        test_pruning = True
        test_resize_embeddings = True

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_initialization(self):
            pass
            # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -62,6 +64,7 @@ class TFCommonTestCases:
            # msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_attention_outputs(self):
            pass
            # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -102,6 +105,7 @@ class TFCommonTestCases:
            # self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_headmasking(self):
            pass
            # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -149,6 +153,7 @@ class TFCommonTestCases:
            # attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_head_pruning(self):
            pass
            # if not self.test_pruning:
@@ -176,6 +181,7 @@ class TFCommonTestCases:
            # attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_hidden_states_output(self):
            pass
            # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -195,6 +201,7 @@ class TFCommonTestCases:
            # [self.model_tester.seq_length, self.model_tester.hidden_size])

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_resize_tokens_embeddings(self):
            pass
            # original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -231,6 +238,7 @@ class TFCommonTestCases:
            # self.assertTrue(models_equal)

        @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
        def test_tie_model_weights(self):
            pass
            # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -282,6 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
class TFModelUtilsTest(unittest.TestCase):

    @pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
    def test_model_from_pretrained(self):
        pass
        # logging.basicConfig(level=logging.INFO)