[WIP] - CLI
huggingface/transformers · commit 72c36b9ea2 · parent e57d00ee10
setup.py (8 lines changed)
@@ -62,15 +62,15 @@ setup(
                      'regex',
                      'sentencepiece',
                      'sacremoses'],
    extras_require=extras,
    scripts=[
        'transformers-cli'
    ],
    entry_points={
        'console_scripts': [
            "transformers=transformers.__main__:main",
        ]
    },
    extras_require=extras,
    scripts=[
        'transformers-cli'
    ],
    # python_requires='>=3.5.0',
    classifiers=[
        'Intended Audience :: Science/Research',
transformers-cli

@@ -1,14 +1,15 @@
#!/usr/bin/env python
from argparse import ArgumentParser

+from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands


if __name__ == '__main__':
-    parser = ArgumentParser(description='Transformers CLI tool', usage='transformers-cli <command> [<args>]')
+    parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
    commands_parser = parser.add_subparsers(help='transformers-cli command helpers')

    # Register commands
+    ServeCommand.register_subcommand(commands_parser)
    UserCommands.register_subcommand(commands_parser)

    # Let's go
transformers/__init__.py

@@ -24,6 +24,8 @@ from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH

from .data import (is_sklearn_available,
                   InputExample, InputFeatures, DataProcessor,
+                  SingleSentenceClassificationProcessor,
+                  convert_examples_to_features,
                   glue_output_modes, glue_convert_examples_to_features,
                   glue_processors, glue_tasks_num_labels,
                   xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
transformers/__main__.py

@@ -1,129 +1,36 @@
 # coding: utf8
 def main():
     import sys
-    if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
+    if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
         print(
-        "This command line utility let you convert original (author released) model checkpoint to pytorch.\n"
-        "It should be used as one of: \n"
-        ">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
-        ">> transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
-        ">> transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
-        ">> transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
-        ">> transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
-        ">> transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
-    else:
-        if sys.argv[1] == "bert":
-            try:
-                from .convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
-            except ImportError:
-                print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
-                    "In that case, it requires TensorFlow to be installed. Please see "
-                    "https://www.tensorflow.org/install/ for installation instructions.")
-                raise
+        "First argument to `transformers` command line interface should be one of: \n"
+        ">> convert serve train predict")
+    if sys.argv[1] == "convert":
+        from transformers.commands import convert
+        convert(sys.argv)
+    elif sys.argv[1] == "train":
+        from transformers.commands import train
+        train(sys.argv)
+    elif sys.argv[1] == "serve":
+        pass
+        # from argparse import ArgumentParser
+        # from transformers.commands.serving import ServeCommand
+        # parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
+        # commands_parser = parser.add_subparsers(help='transformers-cli command helpers')

-            if len(sys.argv) != 5:
-                # pylint: disable=line-too-long
-                print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
-            else:
-                PYTORCH_DUMP_OUTPUT = sys.argv.pop()
-                TF_CONFIG = sys.argv.pop()
-                TF_CHECKPOINT = sys.argv.pop()
-                convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "gpt":
-            from .convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
-            if len(sys.argv) < 4 or len(sys.argv) > 5:
-                # pylint: disable=line-too-long
-                print("Should be used as `transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
-            else:
-                OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
-                PYTORCH_DUMP_OUTPUT = sys.argv[3]
-                if len(sys.argv) == 5:
-                    OPENAI_GPT_CONFIG = sys.argv[4]
-                else:
-                    OPENAI_GPT_CONFIG = ""
-                convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
-                                                     OPENAI_GPT_CONFIG,
-                                                     PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "transfo_xl":
-            try:
-                from .convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
-            except ImportError:
-                print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
-                    "In that case, it requires TensorFlow to be installed. Please see "
-                    "https://www.tensorflow.org/install/ for installation instructions.")
-                raise
-            if len(sys.argv) < 4 or len(sys.argv) > 5:
-                # pylint: disable=line-too-long
-                print("Should be used as `transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
-            else:
-                if 'ckpt' in sys.argv[2].lower():
-                    TF_CHECKPOINT = sys.argv[2]
-                    TF_DATASET_FILE = ""
-                else:
-                    TF_DATASET_FILE = sys.argv[2]
-                    TF_CHECKPOINT = ""
-                PYTORCH_DUMP_OUTPUT = sys.argv[3]
-                if len(sys.argv) == 5:
-                    TF_CONFIG = sys.argv[4]
-                else:
-                    TF_CONFIG = ""
-                convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
-        elif sys.argv[1] == "gpt2":
-            try:
-                from .convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
-            except ImportError:
-                print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
-                    "In that case, it requires TensorFlow to be installed. Please see "
-                    "https://www.tensorflow.org/install/ for installation instructions.")
-                raise
+        # # Register commands
+        # ServeCommand.register_subcommand(commands_parser)

-            if len(sys.argv) < 4 or len(sys.argv) > 5:
-                # pylint: disable=line-too-long
-                print("Should be used as `transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
-            else:
-                TF_CHECKPOINT = sys.argv[2]
-                PYTORCH_DUMP_OUTPUT = sys.argv[3]
-                if len(sys.argv) == 5:
-                    TF_CONFIG = sys.argv[4]
-                else:
-                    TF_CONFIG = ""
-                convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "xlnet":
-            try:
-                from .convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
-            except ImportError:
-                print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
-                    "In that case, it requires TensorFlow to be installed. Please see "
-                    "https://www.tensorflow.org/install/ for installation instructions.")
-                raise
+        # # Let's go
+        # args = parser.parse_args()

-            if len(sys.argv) < 5 or len(sys.argv) > 6:
-                # pylint: disable=line-too-long
-                print("Should be used as `transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
-            else:
-                TF_CHECKPOINT = sys.argv[2]
-                TF_CONFIG = sys.argv[3]
-                PYTORCH_DUMP_OUTPUT = sys.argv[4]
-                if len(sys.argv) == 6:
-                    FINETUNING_TASK = sys.argv[5]
-                else:
-                    FINETUNING_TASK = None
-
-                convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
-                                                    TF_CONFIG,
-                                                    PYTORCH_DUMP_OUTPUT,
-                                                    FINETUNING_TASK)
-        elif sys.argv[1] == "xlm":
-            from .convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
-
-            if len(sys.argv) != 4:
-                # pylint: disable=line-too-long
-                print("Should be used as `transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
-            else:
-                XLM_CHECKPOINT_PATH = sys.argv[2]
-                PYTORCH_DUMP_OUTPUT = sys.argv[3]
-
-                convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)
+        # if not hasattr(args, 'func'):
+        #     parser.print_help()
+        #     exit(1)
+        # # Run
+        # service = args.func(args)
+        # service.run()

 if __name__ == '__main__':
     main()
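Note: the new dispatcher imports `convert` and `train` from `transformers.commands` and calls them with `sys.argv`, while the files added below only define `ConvertCommand`/`TrainCommand` classes exposing `register_subcommand`/`run`. At this WIP stage a thin adapter would still be needed to bridge the two; a sketch, not part of the commit, with the argument routing being an assumption:

from argparse import ArgumentParser

def convert(argv):
    # Hypothetical glue: route the argv-style call from __main__.py to the
    # argparse-based ConvertCommand defined in transformers/commands/convert.py.
    from transformers.commands.convert import ConvertCommand
    parser = ArgumentParser('transformers', usage='transformers convert [<args>]')
    ConvertCommand.register_subcommand(parser.add_subparsers())
    args = parser.parse_args(argv[1:])  # argv[1] is "convert", the rest are flags
    args.func(args).run()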
transformers/commands/convert.py (new file, 115 lines)
@@ -0,0 +1,115 @@
from argparse import ArgumentParser, Namespace

from logging import getLogger

from transformers import AutoModel, AutoTokenizer
from transformers.commands import BaseTransformersCLICommand


def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model's TF 1.0 checkpoint into a PyTorch checkpoint.
    :return: ConvertCommand
    """
    return ConvertCommand(args.model_type, args.tf_checkpoint, args.pytorch_dump_output,
                          args.config, args.finetuning_task_name)


class ConvertCommand(BaseTransformersCLICommand):

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli.
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser('convert', help="CLI tool to convert original author "
                                                         "checkpoints to Transformers PyTorch checkpoints.")
        train_parser.add_argument('--model_type', type=str, required=True,
                                  help="Model's type.")
        train_parser.add_argument('--tf_checkpoint', type=str, required=True,
                                  help='TensorFlow checkpoint path or folder.')
        train_parser.add_argument('--pytorch_dump_output', type=str, required=True,
                                  help='Path to the PyTorch saved model output.')
        train_parser.add_argument('--config', type=str, default="",
                                  help='Configuration file path or folder.')
        train_parser.add_argument('--finetuning_task_name', type=str, default=None,
                                  help='Optional fine-tuning task name if the TF model was a finetuned model.')
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(self, model_type: str, tf_checkpoint: str, pytorch_dump_output: str,
                 config: str, finetuning_task_name: str, *args):
        self._logger = getLogger('transformers-cli/converting')

        self._logger.info('Loading model {}'.format(model_type))
        self._model_type = model_type
        self._tf_checkpoint = tf_checkpoint
        self._pytorch_dump_output = pytorch_dump_output
        self._config = config
        self._finetuning_task_name = finetuning_task_name

    def run(self):
        if self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
            except ImportError:
                msg = "transformers can only be used from the commandline to convert TensorFlow models to PyTorch. " \
                      "In that case, it requires TensorFlow to be installed. Please see " \
                      "https://www.tensorflow.org/install/ for installation instructions."
                raise ImportError(msg)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint,
                                                 self._config,
                                                 self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
            except ImportError:
                msg = "transformers can only be used from the commandline to convert TensorFlow models to PyTorch. " \
                      "In that case, it requires TensorFlow to be installed. Please see " \
                      "https://www.tensorflow.org/install/ for installation instructions."
                raise ImportError(msg)

            if 'ckpt' in self._tf_checkpoint.lower():
                TF_CHECKPOINT = self._tf_checkpoint
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = self._tf_checkpoint
                TF_CHECKPOINT = ""
            convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT,
                                                     self._config,
                                                     self._pytorch_dump_output,
                                                     TF_DATASET_FILE)
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
            except ImportError:
                msg = "transformers can only be used from the commandline to convert TensorFlow models to PyTorch. " \
                      "In that case, it requires TensorFlow to be installed. Please see " \
                      "https://www.tensorflow.org/install/ for installation instructions."
                raise ImportError(msg)

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
            except ImportError:
                msg = "transformers can only be used from the commandline to convert TensorFlow models to PyTorch. " \
                      "In that case, it requires TensorFlow to be installed. Please see " \
                      "https://www.tensorflow.org/install/ for installation instructions."
                raise ImportError(msg)

            convert_xlnet_checkpoint_to_pytorch(self._tf_checkpoint,
                                                self._config,
                                                self._pytorch_dump_output,
                                                self._finetuning_task_name)
        elif self._model_type == "xlm":
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        else:
            raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]")
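Note: once registered, the same conversion performed by the `convert` subcommand can be triggered programmatically; a minimal sketch, not part of the commit — the checkpoint, config and output paths are placeholders, and TensorFlow must be installed for the TF-based model types:

from argparse import Namespace
from transformers.commands.convert import convert_command_factory

args = Namespace(model_type="bert",
                 tf_checkpoint="/path/to/bert_model.ckpt",      # placeholder
                 pytorch_dump_output="/path/to/pytorch_dump",   # placeholder
                 config="/path/to/bert_config.json",            # placeholder
                 finetuning_task_name=None)
convert_command_factory(args).run()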
transformers/commands/serving.py (new file, 176 lines)
@@ -0,0 +1,176 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any

import torch
from fastapi import FastAPI, HTTPException, Body
from logging import getLogger

from pydantic import BaseModel
from uvicorn import run

from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers.commands import BaseTransformersCLICommand


def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate the serving server from the provided command line arguments.
    :return: ServeCommand
    """
    return ServeCommand(args.host, args.port, args.model, args.graphql)


class ServeResult(BaseModel):
    """
    Base class for serving results
    """
    model: str


class ServeModelInfoResult(ServeResult):
    """
    Expose model information
    """
    infos: dict


class ServeTokenizeResult(ServeResult):
    """
    Tokenize result model
    """
    tokens: List[str]
    tokens_ids: Optional[List[int]]


class ServeDeTokenizeResult(ServeResult):
    """
    DeTokenize result model
    """
    text: str


class ServeForwardResult(ServeResult):
    """
    Forward result model
    """
    tokens: List[str]
    tokens_ids: List[int]
    output: Any


class ServeCommand(BaseTransformersCLICommand):

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli.
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
        serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
        serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
        serve_parser.add_argument('--model', type=str, required=True, help="Model's name or path to stored model to infer from.")
        serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.')
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, host: str, port: int, model: str, graphql: bool):
        self._logger = getLogger('transformers-cli/serving')

        self._logger.info('Loading model {}'.format(model))
        self._model_name = model
        self._model = AutoModel.from_pretrained(model)
        self._tokenizer = AutoTokenizer.from_pretrained(model)

        self._logger.info('Serving model over {}:{}'.format(host, port))
        self._host = host
        self._port = port
        self._app = FastAPI()

        # Register routes
        self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
        self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
        self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
        self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])

    def run(self):
        run(self._app, host=self._host, port=self._port)

    def model_info(self):
        return ServeModelInfoResult(model=self._model_name, infos=vars(self._model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and optionally return the corresponding token ids:
        - **text_input**: String to tokenize
        - **return_ids**: Boolean flag indicating whether the tokens should also be converted to their integer mapping.
        """
        try:
            tokens_txt = self._tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})

    def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
                   skip_special_tokens: bool = Body(False, embed=True),
                   cleanup_tokenization_spaces: bool = Body(True, embed=True)):
        """
        Detokenize the provided token ids to readable text:
        - **tokens_ids**: List of token ids
        - **skip_special_tokens**: Flag indicating whether special tokens should be dropped when decoding
        - **cleanup_tokenization_spaces**: Flag indicating whether leading/trailing and doubled spaces should be removed.
        """
        try:
            decoded_str = self._tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            return ServeDeTokenizeResult(model=self._model_name, text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})

    def forward(self, inputs: Union[str, List[str], List[int]] = Body(None, embed=True),
                attention_mask: Optional[List[int]] = Body(None, embed=True),
                tokens_type_ids: Optional[List[int]] = Body(None, embed=True)):
        """
        **inputs**:
        **attention_mask**:
        **tokens_type_ids**:
        """

        # Check we don't have an empty string
        if len(inputs) == 0:
            return ServeForwardResult(model=self._model_name, tokens=[], tokens_ids=[], output=[])

        if isinstance(inputs, str):
            inputs_tokens = self._tokenizer.tokenize(inputs)
            inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)

        elif isinstance(inputs, List):
            if isinstance(inputs[0], str):
                inputs_tokens = inputs
                inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
            elif isinstance(inputs[0], int):
                inputs_tokens = []
                inputs_ids = inputs
            else:
                error_msg = "inputs should be a string, a [str] or a [int] (got {})".format(type(inputs[0]))
                raise HTTPException(423, detail={"error": error_msg})
        else:
            error_msg = "inputs should be a string, a [str] or a [int] (got {})".format(type(inputs))
            raise HTTPException(423, detail={"error": error_msg})

        try:
            # Forward through the model
            t_input_ids = torch.tensor(inputs_ids).unsqueeze(0)
            output = self._model(t_input_ids, attention_mask, tokens_type_ids)

            return ServeForwardResult(
                model=self._model_name, tokens=inputs_tokens,
                tokens_ids=inputs_ids, output=output[0].tolist()
            )
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
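Note: a minimal client sketch for the endpoints registered above, not part of the commit — it assumes the server was started with `transformers-cli serve --model <name-or-path>` using the default `--host localhost --port 8888`:

import requests

# Base URL matches the default host/port declared by the serve subcommand.
base = "http://localhost:8888"

print(requests.get(base + "/").json())  # model information
tokenized = requests.post(base + "/tokenize",
                          json={"text_input": "Hello world", "return_ids": True}).json()
print(tokenized["tokens"], tokenized["tokens_ids"])
detok = requests.post(base + "/detokenize",
                      json={"tokens_ids": tokenized["tokens_ids"]}).json()
print(detok["text"])
forward = requests.post(base + "/forward", json={"inputs": "Hello world"}).json()
print(len(forward["output"]))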
transformers/commands/train.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import os
from argparse import ArgumentParser, Namespace

from logging import getLogger

from transformers.commands import BaseTransformersCLICommand
from transformers import (AutoTokenizer, is_tf_available, is_torch_available,
                          SingleSentenceClassificationProcessor,
                          convert_examples_to_features)
if is_tf_available():
    from transformers import TFAutoModelForSequenceClassification as SequenceClassifModel
elif is_torch_available():
    from transformers import AutoModelForSequenceClassification as SequenceClassifModel
else:
    raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# TF training parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2
USE_XLA = False
USE_AMP = False

def train_command_factory(args: Namespace):
    """
    Factory function used to instantiate the training command from the provided command line arguments.
    :return: TrainCommand
    """
    return TrainCommand(args.model, args.task, args.train_data, args.valid_data, args.valid_data_ratio)


class TrainCommand(BaseTransformersCLICommand):

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli.
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
        train_parser.add_argument('--train_data', type=str, required=True,
                                  help='Path to the train (and optionally evaluation) dataset.')
        train_parser.add_argument('--task', type=str, default='text_classification',
                                  help='Task to train the model on.')
        train_parser.add_argument('--model', type=str, default='bert-base-uncased',
                                  help="Model's name or path to stored model.")
        train_parser.add_argument('--valid_data', type=str, default='',
                                  help='Path to the validation dataset.')
        train_parser.add_argument('--valid_data_ratio', type=float, default=0.1,
                                  help="If no validation dataset is provided, fraction of the train dataset "
                                       "to use as a validation dataset.")
        train_parser.set_defaults(func=train_command_factory)

    def __init__(self, model_name: str, task: str, train_data: str,
                 valid_data: str, valid_data_ratio: float):
        self._logger = getLogger('transformers-cli/training')

        self._framework = 'tf' if is_tf_available() else 'torch'

        self._logger.info('Loading model {}'.format(model_name))
        self._model_name = model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        if task == 'text_classification':
            self._model = SequenceClassifModel.from_pretrained(model_name)
        elif task == 'token_classification':
            raise NotImplementedError
        elif task == 'question_answering':
            raise NotImplementedError

        dataset = SingleSentenceClassificationProcessor.create_from_csv(train_data)
        num_data_samples = len(dataset)
        if valid_data:
            self._train_dataset = dataset
            self._num_train_samples = num_data_samples
            self._valid_dataset = SingleSentenceClassificationProcessor.create_from_csv(valid_data)
            self._num_valid_samples = len(self._valid_dataset)
        else:
            assert 0.0 < valid_data_ratio < 1.0, "--valid_data_ratio should be between 0.0 and 1.0"
            # No separate validation set: carve a slice out of the training data
            self._num_valid_samples = int(num_data_samples * valid_data_ratio)
            self._num_train_samples = num_data_samples - self._num_valid_samples
            self._train_dataset = dataset[:self._num_train_samples]
            self._valid_dataset = dataset[self._num_train_samples:]

    def run(self):
        if self._framework == 'tf':
            return self.run_tf()
        return self.run_torch()

    def run_torch(self):
        raise NotImplementedError

    def run_tf(self):
        import tensorflow as tf

        tf.config.optimizer.set_jit(USE_XLA)
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

        # Prepare the datasets as tf.data.Dataset instances
        train_dataset = convert_examples_to_features(self._train_dataset, self._tokenizer, mode='sequence_classification')
        valid_dataset = convert_examples_to_features(self._valid_dataset, self._tokenizer, mode='sequence_classification')
        train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
        valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)

        # Prepare training: compile the tf.keras model with optimizer, loss and learning rate schedule
        opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
        if USE_AMP:
            # loss scaling is currently required when using mixed precision
            opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        self._model.compile(optimizer=opt, loss=loss, metrics=[metric])

        # Train and evaluate using tf.keras.Model.fit()
        train_steps = self._num_train_samples // BATCH_SIZE
        valid_steps = self._num_valid_samples // EVAL_BATCH_SIZE

        history = self._model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
                                  validation_data=valid_dataset, validation_steps=valid_steps)

        # Save the TF2 model
        os.makedirs('./save/', exist_ok=True)
        self._model.save_pretrained('./save/')
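Note: assuming the factory forwards every parsed argument to TrainCommand, the training command can also be exercised programmatically; a sketch, not part of the commit — the CSV path is a placeholder and either TensorFlow 2 or PyTorch must be installed:

from argparse import Namespace
from transformers.commands.train import train_command_factory

args = Namespace(model="bert-base-uncased",
                 task="text_classification",
                 train_data="./train.csv",   # placeholder, read as tab-separated rows
                 valid_data="",
                 valid_data_ratio=0.1)
train_command_factory(args).run()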
transformers/data/__init__.py

@@ -1,4 +1,4 @@
-from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
+from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
 from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
 from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
 from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
transformers/data/processors/__init__.py

@@ -1,4 +1,4 @@
-from .utils import InputExample, InputFeatures, DataProcessor
+from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor, convert_examples_to_features
 from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
 from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
 from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
transformers/data/processors/utils.py

@@ -125,3 +125,185 @@ class DataProcessor(object):
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines


class SingleSentenceClassificationProcessor(DataProcessor):
    """ Generic processor for a single-sentence classification data set."""
    def __init__(self, labels=None, examples=None):
        self.labels = [] if labels is None else labels
        self.examples = [] if examples is None else examples

    @classmethod
    def create_from_csv(cls, file_name):
        processor = cls()
        processor.add_examples_from_csv(file_name)
        return processor

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return SingleSentenceClassificationProcessor(labels=self.labels,
                                                         examples=self.examples[idx])
        return self.examples[idx]

    def get_labels(self):
        """Gets the list of labels for this data set."""
        return self.labels

    def add_examples_from_csv(self, file_name):
        lines = self._read_tsv(file_name)
        self.add_examples_from_lines(lines)

    def add_examples_from_lines(self, lines, split_name='', overwrite_labels=False, overwrite_examples=False):
        """Creates examples for the training and dev sets."""
        added_labels = set()
        examples = []
        for (i, line) in enumerate(lines):
            if len(line) > 2:
                guid = "%s-%s" % (split_name, line[0]) if split_name else line[0]
                label = line[1]
                text_a = line[2]
            else:
                guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
                label = line[0]
                text_a = line[1]

            added_labels.add(label)
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))

        # Update examples
        if overwrite_examples:
            self.examples = examples
        else:
            self.examples.extend(examples)

        # Update labels
        if overwrite_labels:
            self.labels = list(added_labels)
        else:
            self.labels = list(set(self.labels).union(added_labels))

        return self.examples


def convert_examples_to_features(examples, tokenizer,
                                 mode='sequence_classification',
                                 max_length=512,
                                 task=None,
                                 label_list=None,
                                 output_mode=None,
                                 pad_on_left=False,
                                 pad_token=0,
                                 pad_token_segment_id=0,
                                 mask_padding_with_zero=True):
    """
    Loads a data file into a list of ``InputFeatures``

    Args:
        examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
        tokenizer: Instance of a tokenizer that will tokenize the examples
        max_length: Maximum example length
        task: GLUE task
        label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
        output_mode: String indicating the output mode. Either ``regression`` or ``classification``
        pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
        pad_token: Padding token
        pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
        mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
            and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
            actual values)

    Returns:
        If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
        containing the task-specific features. If the input is a list of ``InputExamples``, will return
        a list of task-specific ``InputFeatures`` which can be fed to the model.

    """
    is_tf_dataset = False
    if is_tf_available() and isinstance(examples, tf.data.Dataset):
        is_tf_dataset = True

    if task is not None:
        processor = glue_processors[task]()
        if label_list is None:
            label_list = processor.get_labels()
            logger.info("Using label list %s for task %s" % (label_list, task))
        if output_mode is None:
            output_mode = glue_output_modes[task]
            logger.info("Using output mode %s for task %s" % (output_mode, task))

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d" % (ex_index))
        if is_tf_dataset:
            example = processor.get_example_from_tensor_dict(example)

        inputs = tokenizer.encode_plus(
            example.text_a,
            example.text_b,
            add_special_tokens=True,
            max_length=max_length,
        )
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)

        if output_mode == "classification":
            label = label_map[example.label]
        elif output_mode == "regression":
            label = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label))

        features.append(
            InputFeatures(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids,
                          label=label))

    if is_tf_available() and is_tf_dataset:
        def gen():
            for ex in features:
                yield ({'input_ids': ex.input_ids,
                        'attention_mask': ex.attention_mask,
                        'token_type_ids': ex.token_type_ids},
                       ex.label)

        return tf.data.Dataset.from_generator(gen,
                                              ({'input_ids': tf.int32,
                                                'attention_mask': tf.int32,
                                                'token_type_ids': tf.int32},
                                               tf.int64),
                                              ({'input_ids': tf.TensorShape([None]),
                                                'attention_mask': tf.TensorShape([None]),
                                                'token_type_ids': tf.TensorShape([None])},
                                               tf.TensorShape([])))

    return features
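Note: the new processor and feature converter can be used directly; a sketch, not part of the commit — the CSV path and model name are placeholders, and it assumes `logger`, `is_tf_available` and the optional TensorFlow import are defined earlier in utils.py, which this hunk does not show:

from transformers import AutoTokenizer
from transformers.data.processors.utils import (SingleSentenceClassificationProcessor,
                                                convert_examples_to_features)

# The file is read as tab-separated (label, text) or (guid, label, text) rows.
processor = SingleSentenceClassificationProcessor.create_from_csv("./train.csv")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
features = convert_examples_to_features(processor.examples, tokenizer,
                                        max_length=128,
                                        label_list=processor.get_labels(),
                                        output_mode="classification")
print(len(features), features[0].input_ids[:10])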