mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge branch 'tf2' of https://github.com/huggingface/pytorch-transformers into tf2
This commit is contained in:
commit
1e47dee24c
@ -23,8 +23,7 @@ import logging
|
||||
|
||||
from pytorch_transformers import is_tf_available
|
||||
|
||||
# if is_tf_available():
|
||||
if False:
|
||||
if is_tf_available():
|
||||
from pytorch_transformers import (AutoConfig, BertConfig,
|
||||
TFAutoModel, TFBertModel,
|
||||
TFAutoModelWithLMHead, TFBertForMaskedLM,
|
||||
@ -44,7 +43,8 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
@ -55,7 +55,8 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
|
||||
def test_lmhead_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
@ -66,7 +67,8 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
|
||||
def test_sequence_classification_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
@ -77,7 +79,8 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
|
||||
def test_question_answering_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
@ -316,7 +316,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
Loading…
Reference in New Issue
Block a user