skipping tf tests if tf is not installed

This commit is contained in:
thomwolf 2019-09-05 03:06:09 +02:00
parent 134847db81
commit aa4c8804f2
2 changed files with 17 additions and 17 deletions

View File

@ -287,48 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def test_config(self):
self.config_tester.run_common_tests()
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_bert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_model(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_next_sequence_prediction(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
@pytest.mark.slow
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

View File

@ -50,7 +50,7 @@ class TFCommonTestCases:
test_pruning = True
test_resize_embeddings = True
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_initialization(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -64,7 +64,7 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_attention_outputs(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -105,7 +105,7 @@ class TFCommonTestCases:
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_headmasking(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -153,7 +153,7 @@ class TFCommonTestCases:
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_head_pruning(self):
pass
# if not self.test_pruning:
@ -181,7 +181,7 @@ class TFCommonTestCases:
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_hidden_states_output(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -201,7 +201,7 @@ class TFCommonTestCases:
# [self.model_tester.seq_length, self.model_tester.hidden_size])
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_resize_tokens_embeddings(self):
pass
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -238,7 +238,7 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_tie_model_weights(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -290,7 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
class TFModelUtilsTest(unittest.TestCase):
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
def test_model_from_pretrained(self):
pass
# logging.basicConfig(level=logging.INFO)