mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Take advantage of the cache when running tests.
Caching models across test cases and across runs of the test suite makes slow tests somewhat more bearable. Use gettempdir() instead of /tmp in tests. This makes it easier to change the location of the cache with semi-standard TMPDIR/TEMP/TMP environment variables. Fix #2222.
This commit is contained in:
parent
b67fa1a8d2
commit
b670c26684
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import XxxConfig, is_tf_available
|
||||
|
||||
@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in ['xxx-base-uncased']:
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
||||
@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
|
||||
@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ import logging
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@ -753,10 +753,8 @@ class CommonTestCases:
|
||||
[[], []])
|
||||
|
||||
def create_and_check_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.parent.assertIsNotNone(model)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
@ -16,7 +16,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pdb
|
||||
|
||||
from transformers import is_torch_available
|
||||
@ -27,7 +26,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# shutil.rmtree(cache_dir)
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -27,7 +26,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -27,7 +26,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -29,7 +28,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (T5Config, T5Model, T5WithLMHeadModel)
|
||||
@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import AlbertConfig, is_tf_available
|
||||
|
||||
@ -217,12 +216,9 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['albert-base-uncased']:
|
||||
model = TFAlbertModel.from_pretrained(
|
||||
model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModel.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertModel)
|
||||
|
||||
@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
|
||||
@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForSequenceClassification)
|
||||
|
||||
@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForQuestionAnswering)
|
||||
|
||||
def test_from_pretrained_identifier(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
|
||||
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import BertConfig, is_tf_available
|
||||
|
||||
@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import CTRLConfig, is_tf_available
|
||||
|
||||
@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import DistilBertConfig, is_tf_available
|
||||
|
||||
@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# shutil.rmtree(cache_dir)
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import GPT2Config, is_tf_available
|
||||
|
||||
@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import OpenAIGPTConfig, is_tf_available
|
||||
|
||||
@ -219,10 +218,8 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,11 +17,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import RobertaConfig, is_tf_available
|
||||
|
||||
@ -192,10 +191,8 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFRobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import T5Config, is_tf_available
|
||||
|
||||
@ -162,10 +161,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in ['t5-small']:
|
||||
model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -18,11 +18,10 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import TransfoXLConfig, is_tf_available
|
||||
|
||||
@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_tf_available
|
||||
|
||||
@ -31,7 +30,7 @@ if is_tf_available():
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
@require_tf
|
||||
@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -20,7 +20,6 @@ import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from transformers import XLNetConfig, is_tf_available
|
||||
|
||||
@ -35,7 +34,7 @@ if is_tf_available():
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
@require_tf
|
||||
@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -18,7 +18,6 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -29,7 +28,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -28,7 +27,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -20,7 +20,6 @@ import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -33,7 +32,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
import os
|
||||
import unittest
|
||||
import tempfile
|
||||
|
||||
from distutils.util import strtobool
|
||||
|
||||
from transformers.file_utils import _tf_available, _torch_available
|
||||
|
||||
|
||||
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user