mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
skipping tf tests if tf is not installed
This commit is contained in:
parent
134847db81
commit
aa4c8804f2
@ -287,48 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_bert_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_model(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_next_sequence_prediction(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_pretraining(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
|
@ -50,7 +50,7 @@ class TFCommonTestCases:
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -64,7 +64,7 @@ class TFCommonTestCases:
|
||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -105,7 +105,7 @@ class TFCommonTestCases:
|
||||
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
|
||||
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_headmasking(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -153,7 +153,7 @@ class TFCommonTestCases:
|
||||
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_head_pruning(self):
|
||||
pass
|
||||
# if not self.test_pruning:
|
||||
@ -181,7 +181,7 @@ class TFCommonTestCases:
|
||||
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -201,7 +201,7 @@ class TFCommonTestCases:
|
||||
# [self.model_tester.seq_length, self.model_tester.hidden_size])
|
||||
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -238,7 +238,7 @@ class TFCommonTestCases:
|
||||
# self.assertTrue(models_equal)
|
||||
|
||||
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -290,7 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
|
||||
|
||||
|
||||
class TFModelUtilsTest(unittest.TestCase):
|
||||
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||
@pytest.mark.skipif('tensorflow' not in sys.modules, reason="requires TensorFlow")
|
||||
def test_model_from_pretrained(self):
|
||||
pass
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
|
Loading…
Reference in New Issue
Block a user