Added download command through the cli.

It allows pre-downloading models and tokenizers.
This commit is contained in:
Morgan Funtowicz 2019-12-03 14:56:57 +01:00
parent 31a3a73ee3
commit 81babb227e
2 changed files with 32 additions and 1 deletion

4
transformers-cli Normal file → Executable file
View File

@ -1,6 +1,7 @@
#!/usr/bin/env python
from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from transformers.commands.train import TrainCommand
@ -11,10 +12,11 @@ if __name__ == '__main__':
commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Register commands
ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser)
TrainCommand.register_subcommand(commands_parser)
ConvertCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()

View File

@ -0,0 +1,29 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
def download_command_factory(args):
    """Build a :class:`DownloadCommand` from a parsed argparse namespace."""
    model, cache_dir, force = args.model, args.cache_dir, args.force
    return DownloadCommand(model, cache_dir, force)
class DownloadCommand(BaseTransformersCLICommand):
    """CLI command that pre-downloads a model and its tokenizer into the local cache.

    Registered as the ``download`` subcommand of ``transformers-cli``.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``download`` subcommand and its arguments to *parser*.

        Args:
            parser: the subparsers action returned by ``add_subparsers()``.
        """
        download_parser = parser.add_parser('download')
        download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
        # Fixed help text: "to be download" -> "to be downloaded".
        download_parser.add_argument('--force', action='store_true', help='Force the model to be downloaded even if already in cache-dir')
        download_parser.add_argument('model', type=str, help='Name of the model to download')
        download_parser.set_defaults(func=download_command_factory)

    def __init__(self, model: str, cache: str, force: bool):
        # model: pretrained model identifier or path; cache: cache directory
        # (None means the library default); force: re-download even if cached.
        self._model = model
        self._cache = cache
        self._force = force

    def run(self):
        """Download both the model weights and the tokenizer files.

        Imported lazily so the CLI stays fast for other subcommands;
        performs network I/O and writes into the cache directory.
        """
        from transformers import AutoModel, AutoTokenizer
        AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
        AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)