Merge pull request #2217 from aaugustin/test-parallelization

Support running tests in parallel
This commit is contained in:
Thomas Wolf 2019-12-21 11:54:23 +01:00 committed by GitHub
commit fae4d1c266
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 174 additions and 226 deletions

View File

@ -1,9 +1,11 @@
version: 2
jobs:
build_py3_torch_and_tf:
run_tests_py3_torch_and_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
@ -11,49 +13,67 @@ jobs:
- run: sudo pip install torch
- run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov
build_py3_torch:
run_tests_py3_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov
build_py3_tf:
run_tests_py3_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov
build_py3_custom_tokenizers:
run_tests_py3_custom_tokenizers:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
steps:
- checkout
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest
- run: sudo pip install pytest pytest-xdist
- run: sudo pip install mecab-python3
- run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
run_examples_py3_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -n 8 --dist=loadfile -s -v ./examples/
deploy_doc:
working_directory: ~/transformers
docker:
@ -66,7 +86,7 @@ jobs:
- run: sudo pip install --progress-bar off -r docs/requirements.txt
- run: sudo pip install --progress-bar off -r requirements.txt
- run: ./.circleci/deploy.sh
repository_consistency:
check_repository_consistency:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
@ -85,9 +105,10 @@ workflows:
version: 2
build_and_test:
jobs:
- repository_consistency
- build_py3_custom_tokenizers
- build_py3_torch_and_tf
- build_py3_torch
- build_py3_tf
- check_repository_consistency
- run_examples_py3_torch
- run_tests_py3_custom_tokenizers
- run_tests_py3_torch_and_tf
- run_tests_py3_torch
- run_tests_py3_tf
- deploy_doc: *workflow_filters

View File

@ -59,6 +59,7 @@ setup(
"tests.*", "tests"]),
install_requires=['numpy',
'boto3',
'filelock',
'requests',
'tqdm',
'regex != 2019.12.17',

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import XxxConfig, is_tf_available
@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in ['xxx-base-uncased']:
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -10,10 +10,9 @@ import json
import logging
import os
import six
import shutil
import tempfile
import fnmatch
from functools import wraps
from functools import partial, wraps
from hashlib import sha256
from io import open
@ -25,6 +24,8 @@ from tqdm.auto import tqdm
from contextlib import contextmanager
from . import __version__
from filelock import FileLock
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
@ -334,59 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
if not file.endswith('.json') and not file.endswith('.lock')
]
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if resume_download:
incomplete_path = cache_path + '.incomplete'
@contextmanager
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f:
yield f
os.remove(incomplete_path)
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = tempfile.NamedTemporaryFile
resume_size = 0
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + '.lock'
with FileLock(lock_path):
if etag is not None and (not os.path.exists(cache_path) or force_download):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
if resume_download:
incomplete_path = cache_path + '.incomplete'
@contextmanager
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f:
yield f
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
resume_size = 0
else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
if etag is not None and (not os.path.exists(cache_path) or force_download):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str):
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta_file.write(output_string)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
logger.info("removing temp file %s", temp_file.name)
logger.info("storing %s in cache at %s", url, cache_path)
os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str):
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta_file.write(output_string)
return cache_path

View File

@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM,
@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -18,7 +18,7 @@ from __future__ import print_function
import copy
import sys
import os
import os.path
import shutil
import tempfile
import json
@ -30,7 +30,7 @@ import logging
from transformers import is_torch_available
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
import torch
@ -218,21 +218,22 @@ class CommonTestCases:
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
try:
torch.jit.trace(model, inputs)
traced_gpt2 = torch.jit.trace(model, inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
try:
traced_gpt2 = torch.jit.trace(model, inputs)
torch.jit.save(traced_gpt2, "traced_model.pt")
except RuntimeError:
self.fail("Couldn't save module.")
with TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
loaded_model = torch.jit.load("traced_model.pt")
os.remove("traced_model.pt")
except ValueError:
self.fail("Couldn't load module.")
try:
torch.jit.save(traced_gpt2, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
@ -352,12 +353,11 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]}
model.prune_heads(heads_to_prune)
directory = "pruned_model"
if not os.path.exists(directory):
os.makedirs(directory)
model.save_pretrained(directory)
model = model_class.from_pretrained(directory)
model.to(torch_device)
with TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
model = model_class.from_pretrained(temp_dir_name)
model.to(torch_device)
with torch.no_grad():
outputs = model(**inputs_dict)
@ -366,7 +366,6 @@ class CommonTestCases:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
shutil.rmtree(directory)
def test_head_pruning_save_load_from_config_init(self):
if not self.test_pruning:
@ -426,14 +425,10 @@ class CommonTestCases:
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
directory = "pruned_model"
if not os.path.exists(directory):
os.makedirs(directory)
model.save_pretrained(directory)
model = model_class.from_pretrained(directory)
model.to(torch_device)
shutil.rmtree(directory)
with TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
model = model_class.from_pretrained(temp_dir_name)
model.to(torch_device)
with torch.no_grad():
outputs = model(**inputs_dict)
@ -758,10 +753,8 @@ class CommonTestCases:
[[], []])
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.parent.assertIsNotNone(model)
def prepare_config_and_inputs_for_common(self):

View File

@ -16,7 +16,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pdb
from transformers import is_torch_available
@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -27,7 +27,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# @slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
@ -29,7 +28,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (T5Config, T5Model, T5WithLMHeadModel)
@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import AlbertConfig, is_tf_available
@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['albert-base-uncased']:
model = TFAlbertModel.from_pretrained(
model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModel.from_pretrained(model_name, force_download=True)
model = TFAutoModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertModel)
@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
model = TFAutoModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForMaskedLM)
@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForSequenceClassification)
@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForQuestionAnswering)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import BertConfig, is_tf_available
@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import CTRLConfig, is_tf_available
@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -20,7 +20,7 @@ import unittest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import DistilBertConfig, is_tf_available
@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
# @slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import GPT2Config, is_tf_available
@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import OpenAIGPTConfig, is_tf_available
@ -219,10 +218,8 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -17,11 +17,10 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import RobertaConfig, is_tf_available
@ -192,10 +191,8 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFRobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import T5Config, is_tf_available
@ -162,10 +161,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in ['t5-small']:
model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":

View File

@ -18,11 +18,10 @@ from __future__ import print_function
import unittest
import random
import shutil
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import TransfoXLConfig, is_tf_available
@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_tf_available
@ -31,7 +30,7 @@ if is_tf_available():
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
@require_tf
@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFXLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -20,7 +20,6 @@ import os
import unittest
import json
import random
import shutil
from transformers import XLNetConfig, is_tf_available
@ -35,7 +34,7 @@ if is_tf_available():
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
@require_tf
@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -18,7 +18,6 @@ from __future__ import print_function
import unittest
import random
import shutil
from transformers import is_torch_available
@ -29,7 +28,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
@ -28,7 +27,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -20,7 +20,6 @@ import os
import unittest
import json
import random
import shutil
from transformers import is_torch_available
@ -33,7 +32,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)

View File

@ -1,11 +1,14 @@
import os
import unittest
import tempfile
from distutils.util import strtobool
from transformers.file_utils import _tf_available, _torch_available
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"