diff --git a/docs/source/installation.md b/docs/source/installation.md
index 528a7489ab2..062b0cb9338 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -155,6 +155,31 @@
 If you expect to be downloading large volumes of models (more than 1,000) from our hosted bucket (for instance through
 your CI setup, or a large-scale production deployment), please cache the model files on your end. It will be way
 faster, and cheaper. Feel free to contact us privately if you need any help.
 
+### Offline mode
+
+It's possible to run 🤗 Transformers in a firewalled or no-network environment.
+
+Setting the environment variable `TRANSFORMERS_OFFLINE=1` tells 🤗 Transformers to use local files only: it will not try to look anything up online.
+
+You will most likely want to couple this with `HF_DATASETS_OFFLINE=1`, which does the same for 🤗 Datasets if you're using the latter.
+
+Here is an example of how this can be used on a filesystem that is shared between a normally networked instance and an instance that is firewalled off from the external world.
+
+On the instance with normal network access, run your program, which will download and cache the models (and, optionally, the datasets if you use 🤗 Datasets). For example:
+
+```
+python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
+```
+
+Then, using the same filesystem, run the same program on the firewalled instance:
+
+```
+HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
+python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
+```
+
+It should succeed without hanging while waiting for a connection to time out.
+
 ## Do you want to run a Transformer model on a mobile device?
 
 You should check out our [swift-coreml-transformers](https://github.com/huggingface/swift-coreml-transformers) repo.
diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py
index f399489e1db..db278073627 100755
--- a/examples/seq2seq/run_seq2seq.py
+++ b/examples/seq2seq/run_seq2seq.py
@@ -44,15 +44,22 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
+from transformers.file_utils import is_offline_mode
 from transformers.trainer_utils import get_last_checkpoint, is_main_process
 
-with FileLock(".lock") as lock:
-    nltk.download("punkt", quiet=True)
-
-
 logger = logging.getLogger(__name__)
 
+try:
+    nltk.data.find("tokenizers/punkt")
+except LookupError:
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
 
 @dataclass
 class ModelArguments:
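A note on the `run_seq2seq.py` hunk above: the unconditional `nltk.download()` at import time is replaced by a probe-then-download guard, so a firewalled run fails fast with an actionable message instead of hanging. A minimal standalone sketch of the same pattern, assuming the `filelock` package is installed (the `ensure_punkt` helper name and the simplified `== "1"` environment check are illustrative, not part of the patch):

```python
import os

import nltk
from filelock import FileLock


def ensure_punkt():
    """Probe the local nltk data cache; download only when networking is allowed."""
    try:
        # cheap local check - succeeds if the punkt files were cached earlier
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
            # offline: fail fast with an actionable message instead of hanging
            raise LookupError(
                "punkt is missing - run once without TRANSFORMERS_OFFLINE to download it"
            )
        # the lock keeps concurrent processes (e.g. distributed launches)
        # from racing on the same download
        with FileLock(".lock"):
            nltk.download("punkt", quiet=True)
```

Probing `nltk.data.find()` first also means networked runs skip the download entirely once the data is cached.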
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 0d0b410e0be..4e5de613867 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -22,7 +22,7 @@ import os
 from typing import Any, Dict, Tuple, Union
 
 from . import __version__
-from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
+from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
 from .utils import logging
@@ -412,6 +412,10 @@ class PretrainedConfig(object):
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
 
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         if os.path.isdir(pretrained_model_name_or_path):
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index 8d3008d1c4c..e6309caaa74 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -234,6 +234,13 @@ PRESET_MIRROR_DICT = {
 }
 
 
+_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
+
+
+def is_offline_mode():
+    return _is_offline_mode
+
+
 def is_torch_available():
     return _torch_available
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 40601bdf31e..9a4f421a0de 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -36,6 +36,7 @@ from .file_utils import (
     ModelOutput,
     cached_path,
     hf_bucket_url,
+    is_offline_mode,
     is_remote_url,
     replace_return_docstrings,
 )
@@ -964,6 +965,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)
 
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
             config_path = config if config is not None else pretrained_model_name_or_path
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index ebe27b6829a..aefe209b65e 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -44,6 +44,7 @@ from .file_utils import (
     cached_path,
     hf_bucket_url,
     is_flax_available,
+    is_offline_mode,
     is_remote_url,
     is_tf_available,
     is_tokenizers_available,
@@ -1597,6 +1598,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
 
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
         s3_models = list(cls.max_model_input_sizes.keys())
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         vocab_files = {}
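The same four-line guard appears in `PretrainedConfig`, `PreTrainedModel`, and `PreTrainedTokenizerBase` above, so in offline mode every `from_pretrained()` call is quietly downgraded to a cache-only lookup. For a caller, the effect is roughly equivalent to passing `local_files_only=True` explicitly; a sketch (`bert-base-uncased` is just an example checkpoint):

```python
from transformers import BertConfig, BertModel, BertTokenizer

mname = "bert-base-uncased"

# on a networked machine this downloads the files and stores them in the cache;
# with TRANSFORMERS_OFFLINE=1 set, it behaves like the explicit calls below
config = BertConfig.from_pretrained(mname)

# explicit cache-only equivalents: never touch the network, raise if not cached
config = BertConfig.from_pretrained(mname, local_files_only=True)
model = BertModel.from_pretrained(mname, local_files_only=True)
tokenizer = BertTokenizer.from_pretrained(mname, local_files_only=True)
```

Forcing the flag centrally means existing scripts gain offline behavior from the environment variable alone, without any code changes. The new test below exercises exactly this path.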
diff --git a/tests/test_offline.py b/tests/test_offline.py
new file mode 100644
index 00000000000..5217c5d6af5
--- /dev/null
+++ b/tests/test_offline.py
@@ -0,0 +1,53 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import sys
+
+from transformers.testing_utils import TestCasePlus, require_torch
+
+
+class OfflineTests(TestCasePlus):
+    @require_torch
+    def test_offline_mode(self):
+
+        # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
+        # `transformers` is loaded, and it's too late to change it from within pytest - so
+        # we change it while running the checks in an external program
+
+        # python one-liner segments
+        load = "from transformers import BertConfig, BertModel, BertTokenizer;"
+        run = "mname = 'lysandre/tiny-bert-random'; BertConfig.from_pretrained(mname) and BertModel.from_pretrained(mname) and BertTokenizer.from_pretrained(mname);"
+        mock = 'import socket; exec("def offline_socket(*args, **kwargs): raise socket.error(\\"Offline mode is enabled.\\")"); socket.socket = offline_socket;'
+
+        # baseline - just load from_pretrained with a normal network
+        cmd = [sys.executable, "-c", f"{load} {run}"]
+
+        # should succeed
+        env = self.get_env()
+        result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+        self.assertEqual(result.returncode, 0, result.stderr)
+
+        # next emulate no network
+        cmd = [sys.executable, "-c", f"{load} {mock} {run}"]
+
+        # should normally fail as it will fail to look up the model files w/o the network
+        env["TRANSFORMERS_OFFLINE"] = "0"
+        result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+        self.assertEqual(result.returncode, 1, result.stderr)
+
+        # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
+        env["TRANSFORMERS_OFFLINE"] = "1"
+        result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+        self.assertEqual(result.returncode, 0, result.stderr)
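For readability, the `mock` one-liner in the test above is just a compressed form of the following monkey-patch (a sketch; the error message is arbitrary). It is what lets the test emulate a no-network host without touching any firewall:

```python
import socket


def offline_socket(*args, **kwargs):
    # refuse every connection attempt immediately, instead of letting the
    # request hang until it times out - which is exactly what a firewalled
    # host would otherwise look like
    raise socket.error("Offline mode is enabled.")


# anything built on the socket module (urllib3, requests, ...) now fails fast
socket.socket = offline_socket
```

The `exec` wrapper in the test exists only because a `def` statement cannot appear directly in a `;`-separated `python -c` one-liner.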