mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Merge pull request #2217 from aaugustin/test-parallelization
Support running tests in parallel
This commit is contained in:
commit
fae4d1c266
@ -1,9 +1,11 @@
|
||||
version: 2
|
||||
jobs:
|
||||
build_py3_torch_and_tf:
|
||||
run_tests_py3_torch_and_tf:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
environment:
|
||||
OMP_NUM_THREADS: 1
|
||||
resource_class: xlarge
|
||||
parallelism: 1
|
||||
steps:
|
||||
@ -11,49 +13,67 @@ jobs:
|
||||
- run: sudo pip install torch
|
||||
- run: sudo pip install tensorflow
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -sv ./transformers/tests/ --cov
|
||||
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
|
||||
- run: codecov
|
||||
build_py3_torch:
|
||||
run_tests_py3_torch:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
environment:
|
||||
OMP_NUM_THREADS: 1
|
||||
resource_class: xlarge
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: sudo pip install torch
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -sv ./transformers/tests/ --cov
|
||||
- run: python -m pytest -sv ./examples/
|
||||
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
|
||||
- run: codecov
|
||||
build_py3_tf:
|
||||
run_tests_py3_tf:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
environment:
|
||||
OMP_NUM_THREADS: 1
|
||||
resource_class: xlarge
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: sudo pip install tensorflow
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -sv ./transformers/tests/ --cov
|
||||
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
|
||||
- run: codecov
|
||||
build_py3_custom_tokenizers:
|
||||
run_tests_py3_custom_tokenizers:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
steps:
|
||||
- checkout
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest
|
||||
- run: sudo pip install pytest pytest-xdist
|
||||
- run: sudo pip install mecab-python3
|
||||
- run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
|
||||
run_examples_py3_torch:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
environment:
|
||||
OMP_NUM_THREADS: 1
|
||||
resource_class: xlarge
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: sudo pip install torch
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest pytest-xdist
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -n 8 --dist=loadfile -s -v ./examples/
|
||||
deploy_doc:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
@ -66,7 +86,7 @@ jobs:
|
||||
- run: sudo pip install --progress-bar off -r docs/requirements.txt
|
||||
- run: sudo pip install --progress-bar off -r requirements.txt
|
||||
- run: ./.circleci/deploy.sh
|
||||
repository_consistency:
|
||||
check_repository_consistency:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
@ -85,9 +105,10 @@ workflows:
|
||||
version: 2
|
||||
build_and_test:
|
||||
jobs:
|
||||
- repository_consistency
|
||||
- build_py3_custom_tokenizers
|
||||
- build_py3_torch_and_tf
|
||||
- build_py3_torch
|
||||
- build_py3_tf
|
||||
- check_repository_consistency
|
||||
- run_examples_py3_torch
|
||||
- run_tests_py3_custom_tokenizers
|
||||
- run_tests_py3_torch_and_tf
|
||||
- run_tests_py3_torch
|
||||
- run_tests_py3_tf
|
||||
- deploy_doc: *workflow_filters
|
||||
|
1
setup.py
1
setup.py
@ -59,6 +59,7 @@ setup(
|
||||
"tests.*", "tests"]),
|
||||
install_requires=['numpy',
|
||||
'boto3',
|
||||
'filelock',
|
||||
'requests',
|
||||
'tqdm',
|
||||
'regex != 2019.12.17',
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import XxxConfig, is_tf_available
|
||||
|
||||
@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in ['xxx-base-uncased']:
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
||||
@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -10,10 +10,9 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
import shutil
|
||||
import tempfile
|
||||
import fnmatch
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from io import open
|
||||
|
||||
@ -25,6 +24,8 @@ from tqdm.auto import tqdm
|
||||
from contextlib import contextmanager
|
||||
from . import __version__
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
@ -334,59 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
||||
# If we don't have a connection (etag is None) and can't identify the file
|
||||
# try to get the last downloaded one
|
||||
if not os.path.exists(cache_path) and etag is None:
|
||||
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
||||
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
||||
if not file.endswith('.json') and not file.endswith('.lock')
|
||||
]
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + '.incomplete'
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path,'a+b') as f:
|
||||
yield f
|
||||
os.remove(incomplete_path)
|
||||
temp_file_manager = _resumable_file_manager
|
||||
if os.path.exists(incomplete_path):
|
||||
resume_size = os.stat(incomplete_path).st_size
|
||||
else:
|
||||
resume_size = 0
|
||||
else:
|
||||
temp_file_manager = tempfile.NamedTemporaryFile
|
||||
resume_size = 0
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + '.lock'
|
||||
with FileLock(lock_path):
|
||||
|
||||
if etag is not None and (not os.path.exists(cache_path) or force_download):
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + '.incomplete'
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path,'a+b') as f:
|
||||
yield f
|
||||
temp_file_manager = _resumable_file_manager
|
||||
if os.path.exists(incomplete_path):
|
||||
resume_size = os.stat(incomplete_path).st_size
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
resume_size = 0
|
||||
else:
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||
resume_size = 0
|
||||
|
||||
# we are copying the file before closing it, so flush to avoid truncation
|
||||
temp_file.flush()
|
||||
# shutil.copyfileobj() starts at the current position, so go to the start
|
||||
temp_file.seek(0)
|
||||
if etag is not None and (not os.path.exists(cache_path) or force_download):
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
|
||||
with open(cache_path, 'wb') as cache_file:
|
||||
shutil.copyfileobj(temp_file, cache_file)
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {'url': url, 'etag': etag}
|
||||
meta_path = cache_path + '.json'
|
||||
with open(meta_path, 'w') as meta_file:
|
||||
output_string = json.dumps(meta)
|
||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
|
||||
meta_file.write(output_string)
|
||||
# we are copying the file before closing it, so flush to avoid truncation
|
||||
temp_file.flush()
|
||||
|
||||
logger.info("removing temp file %s", temp_file.name)
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.rename(temp_file.name, cache_path)
|
||||
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {'url': url, 'etag': etag}
|
||||
meta_path = cache_path + '.json'
|
||||
with open(meta_path, 'w') as meta_file:
|
||||
output_string = json.dumps(meta)
|
||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
|
||||
meta_file.write(output_string)
|
||||
|
||||
return cache_path
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
|
||||
@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
import os.path
|
||||
import shutil
|
||||
import tempfile
|
||||
import json
|
||||
@ -30,7 +30,7 @@ import logging
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@ -218,21 +218,22 @@ class CommonTestCases:
|
||||
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
|
||||
|
||||
try:
|
||||
torch.jit.trace(model, inputs)
|
||||
traced_gpt2 = torch.jit.trace(model, inputs)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
try:
|
||||
traced_gpt2 = torch.jit.trace(model, inputs)
|
||||
torch.jit.save(traced_gpt2, "traced_model.pt")
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't save module.")
|
||||
with TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load("traced_model.pt")
|
||||
os.remove("traced_model.pt")
|
||||
except ValueError:
|
||||
self.fail("Couldn't load module.")
|
||||
try:
|
||||
torch.jit.save(traced_gpt2, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load(pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't load module.")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -352,12 +353,11 @@ class CommonTestCases:
|
||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0]}
|
||||
model.prune_heads(heads_to_prune)
|
||||
directory = "pruned_model"
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
model.save_pretrained(directory)
|
||||
model = model_class.from_pretrained(directory)
|
||||
model.to(torch_device)
|
||||
|
||||
with TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
@ -366,7 +366,6 @@ class CommonTestCases:
|
||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
|
||||
shutil.rmtree(directory)
|
||||
|
||||
def test_head_pruning_save_load_from_config_init(self):
|
||||
if not self.test_pruning:
|
||||
@ -426,14 +425,10 @@ class CommonTestCases:
|
||||
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||
|
||||
directory = "pruned_model"
|
||||
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
model.save_pretrained(directory)
|
||||
model = model_class.from_pretrained(directory)
|
||||
model.to(torch_device)
|
||||
shutil.rmtree(directory)
|
||||
with TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
@ -758,10 +753,8 @@ class CommonTestCases:
|
||||
[[], []])
|
||||
|
||||
def create_and_check_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.parent.assertIsNotNone(model)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
@ -16,7 +16,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pdb
|
||||
|
||||
from transformers import is_torch_available
|
||||
@ -27,7 +26,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# shutil.rmtree(cache_dir)
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -27,7 +26,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -27,7 +26,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -29,7 +28,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (T5Config, T5Model, T5WithLMHeadModel)
|
||||
@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import AlbertConfig, is_tf_available
|
||||
|
||||
@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['albert-base-uncased']:
|
||||
model = TFAlbertModel.from_pretrained(
|
||||
model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModel.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertModel)
|
||||
|
||||
@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
|
||||
@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForSequenceClassification)
|
||||
|
||||
@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
config = AutoConfig.from_pretrained(model_name, force_download=True)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForQuestionAnswering)
|
||||
|
||||
def test_from_pretrained_identifier(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
|
||||
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import BertConfig, is_tf_available
|
||||
|
||||
@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ['bert-base-uncased']:
|
||||
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import CTRLConfig, is_tf_available
|
||||
|
||||
@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import DistilBertConfig, is_tf_available
|
||||
|
||||
@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# cache_dir = "/tmp/transformers_test/"
|
||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# shutil.rmtree(cache_dir)
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import GPT2Config, is_tf_available
|
||||
|
||||
@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import OpenAIGPTConfig, is_tf_available
|
||||
|
||||
@ -219,10 +218,8 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,11 +17,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import RobertaConfig, is_tf_available
|
||||
|
||||
@ -192,10 +191,8 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFRobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import T5Config, is_tf_available
|
||||
|
||||
@ -162,10 +161,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in ['t5-small']:
|
||||
model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -18,11 +18,10 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import TransfoXLConfig, is_tf_available
|
||||
|
||||
@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_tf_available
|
||||
|
||||
@ -31,7 +30,7 @@ if is_tf_available():
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
@require_tf
|
||||
@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFXLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -20,7 +20,6 @@ import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from transformers import XLNetConfig, is_tf_available
|
||||
|
||||
@ -35,7 +34,7 @@ if is_tf_available():
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
@require_tf
|
||||
@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -18,7 +18,6 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -29,7 +28,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -28,7 +27,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -20,7 +20,6 @@ import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
@ -33,7 +32,7 @@ if is_torch_available():
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
import os
|
||||
import unittest
|
||||
import tempfile
|
||||
|
||||
from distutils.util import strtobool
|
||||
|
||||
from transformers.file_utils import _tf_available, _torch_available
|
||||
|
||||
|
||||
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user