This commit is contained in:
thomwolf 2019-09-23 22:08:10 +02:00
commit 1e47dee24c
2 changed files with 11 additions and 7 deletions

View File

@ -23,8 +23,7 @@ import logging
from pytorch_transformers import is_tf_available
# if is_tf_available():
if False:
if is_tf_available():
from pytorch_transformers import (AutoConfig, BertConfig,
TFAutoModel, TFBertModel,
TFAutoModelWithLMHead, TFBertForMaskedLM,
@ -44,7 +43,8 @@ class TFAutoModelTest(unittest.TestCase):
self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
@ -55,7 +55,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_lmhead_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
@ -66,7 +67,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_sequence_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
@ -77,7 +79,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)

View File

@ -316,7 +316,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)