From 798da627ebb24cf729bb55575e69e5d8caf91332 Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Mon, 23 Sep 2019 12:06:10 -0400
Subject: [PATCH] Fix TFBert tests in Python 3.5

---
 .../tests/modeling_tf_auto_test.py | 15 +++++++++------
 .../tests/modeling_tf_bert_test.py |  3 ++-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/pytorch_transformers/tests/modeling_tf_auto_test.py b/pytorch_transformers/tests/modeling_tf_auto_test.py
index 70b13c54a8d..7b080bafcd6 100644
--- a/pytorch_transformers/tests/modeling_tf_auto_test.py
+++ b/pytorch_transformers/tests/modeling_tf_auto_test.py
@@ -23,8 +23,7 @@ import logging
 
 from pytorch_transformers import is_tf_available
 
-# if is_tf_available():
-if False:
+if is_tf_available():
     from pytorch_transformers import (AutoConfig, BertConfig,
                                       TFAutoModel, TFBertModel,
                                       TFAutoModelWithLMHead, TFBertForMaskedLM,
@@ -44,7 +43,8 @@ class TFAutoModelTest(unittest.TestCase):
         self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
 
         logging.basicConfig(level=logging.INFO)
-        for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in ['bert-base-uncased']:
             config = AutoConfig.from_pretrained(model_name, force_download=True)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)
@@ -55,7 +55,8 @@ class TFAutoModelTest(unittest.TestCase):
 
     def test_lmhead_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in ['bert-base-uncased']:
             config = AutoConfig.from_pretrained(model_name, force_download=True)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)
@@ -66,7 +67,8 @@ class TFAutoModelTest(unittest.TestCase):
 
     def test_sequence_classification_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in ['bert-base-uncased']:
             config = AutoConfig.from_pretrained(model_name, force_download=True)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)
@@ -77,7 +79,8 @@ class TFAutoModelTest(unittest.TestCase):
 
     def test_question_answering_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
-        for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in ['bert-base-uncased']:
             config = AutoConfig.from_pretrained(model_name, force_download=True)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)
diff --git a/pytorch_transformers/tests/modeling_tf_bert_test.py b/pytorch_transformers/tests/modeling_tf_bert_test.py
index a1bf9bf794c..b12f113c0f5 100644
--- a/pytorch_transformers/tests/modeling_tf_bert_test.py
+++ b/pytorch_transformers/tests/modeling_tf_bert_test.py
@@ -316,7 +316,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in ['bert-base-uncased']:
             model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
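
Note (not part of the patch): a plausible reading of the pinned 'bert-base-uncased' loop is that plain dicts do not guarantee key order before Python 3.6, so slicing list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1] can pick a different checkpoint on Python 3.5 from one interpreter run to the next. The sketch below is illustrative only; it uses a hypothetical stand-in dict with placeholder values rather than the real archive map.

# Illustrative sketch, not from the repository: a stand-in for
# TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP with placeholder values.
archive_map = {
    "bert-base-uncased": "placeholder-url-1",
    "bert-large-uncased": "placeholder-url-2",
}

# Order-dependent: before Python 3.6, dict key order is arbitrary (and string
# hash randomization can change it between runs), so this slice may yield
# either checkpoint name.
order_dependent = list(archive_map.keys())[:1]

# Deterministic: pin the checkpoint the test means to exercise, as the patch does.
for model_name in ["bert-base-uncased"]:
    print(model_name)  # always 'bert-base-uncased'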