"""Shared helpers for the test suite: environment flags and skip decorators."""

import os
import tempfile
import unittest

from transformers.file_utils import _tf_available, _torch_available

try:
    from distutils.util import strtobool
except ImportError:
    # distutils was deprecated by PEP 632 and removed from the standard library
    # in Python 3.12. Provide a drop-in replacement with the same accepted
    # values, return values (1 / 0) and error so this module still imports.
    def strtobool(val):
        """Convert a string representation of truth to 1 (true) or 0 (false).

        True values are "y", "yes", "t", "true", "on" and "1"; false values
        are "n", "no", "f", "false", "off" and "0" (case-insensitive).

        Raises:
            ValueError: if ``val`` is none of the values above.
        """
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))


# Directory under the system temp dir, joined from gettempdir() and a fixed name.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")

SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"


def parse_flag_from_env(key, default=False):
    """Read a yes/no flag from the environment variable ``key``.

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when ``key`` is not set.

    Returns:
        ``default`` when the variable is unset; otherwise the result of
        ``strtobool`` on its value, i.e. ``1`` for a truthy string and ``0``
        for a falsy one.

    Raises:
        ValueError: if the variable is set to a string ``strtobool`` rejects.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        _value = default
    else:
        # KEY is set, convert it to True or False.
        try:
            _value = strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError("If set, {} must be yes or no.".format(key))
    return _value


# Both flags are evaluated once, at import time; changing the environment
# afterwards has no effect on the decorators below.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)


def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable
    to a truthy value to run them.

    """
    if not _run_slow_tests:
        test_case = unittest.skip("test is slow")(test_case)
    return test_case


def custom_tokenizers(test_case):
    """
    Decorator marking a test for a custom tokenizer.

    Custom tokenizers require additional dependencies, and are skipped
    by default. Set the RUN_CUSTOM_TOKENIZERS environment variable
    to a truthy value to run them.

    """
    if not _run_custom_tokenizers:
        test_case = unittest.skip("test of custom tokenizers")(test_case)
    return test_case


def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed.

    """
    if not _torch_available:
        test_case = unittest.skip("test requires PyTorch")(test_case)
    return test_case


def require_tf(test_case):
    """
    Decorator marking a test that requires TensorFlow.

    These tests are skipped when TensorFlow isn't installed.

    """
    if not _tf_available:
        test_case = unittest.skip("test requires TensorFlow")(test_case)
    return test_case


if _torch_available:
    # Set the USE_CUDA environment variable to select a GPU.
    torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
else:
    # Without PyTorch there is no device to name.
    torch_device = None