Adding pipeline task aliases. (#11247)

* Adding task aliases and adding `token-classification` and
`text-classification` tasks.

* Cleaning docstring.
This commit is contained in:
Nicolas Patry 2021-04-15 09:51:24 +02:00 committed by GitHub
parent aaaed56ffc
commit c3fcba3219
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 12 deletions

View File

@@ -14,7 +14,7 @@
from argparse import ArgumentParser
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
@@ -63,7 +63,9 @@ class RunCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
run_parser.add_argument(
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
)
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")

View File

@@ -15,7 +15,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
@@ -102,7 +102,10 @@ class ServeCommand(BaseTransformersCLICommand):
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
)
serve_parser.add_argument(
"--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
"--task",
type=str,
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")

View File

@@ -93,6 +93,10 @@ logger = logging.get_logger(__name__)
# Register all the supported tasks here
TASK_ALIASES = {
"sentiment-analysis": "text-classification",
"ner": "token-classification",
}
SUPPORTED_TASKS = {
"feature-extraction": {
"impl": FeatureExtractionPipeline,
@@ -100,7 +104,7 @@ SUPPORTED_TASKS = {
"pt": AutoModel if is_torch_available() else None,
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
},
"sentiment-analysis": {
"text-classification": {
"impl": TextClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
@@ -111,7 +115,7 @@ SUPPORTED_TASKS = {
},
},
},
"ner": {
"token-classification": {
"impl": TokenClassificationPipeline,
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
"pt": AutoModelForTokenClassification if is_torch_available() else None,
@@ -206,8 +210,10 @@ def check_task(task: str) -> Tuple[Dict, Any]:
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`
- :obj:`"sentiment-analysis"`
- :obj:`"ner"`
- :obj:`"text-classification"`
- :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`)
- :obj:`"token-classification"`
- :obj:`"ner"` (alias of :obj:`"token-classification"`)
- :obj:`"question-answering"`
- :obj:`"fill-mask"`
- :obj:`"summarization"`
@@ -222,6 +228,8 @@ def check_task(task: str) -> Tuple[Dict, Any]:
"""
if task in TASK_ALIASES:
task = TASK_ALIASES[task]
if task in SUPPORTED_TASKS:
targeted_task = SUPPORTED_TASKS[task]
return targeted_task, None
@@ -264,8 +272,12 @@ def pipeline(
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
- :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
- :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
- :obj:`"text-classification"`: will return a :class:`~transformers.TextClassificationPipeline`.
- :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`): will return a
:class:`~transformers.TextClassificationPipeline`.
- :obj:`"token-classification"`: will return a :class:`~transformers.TokenClassificationPipeline`.
- :obj:`"ner"` (alias of :obj:`"token-classification"`): will return a
:class:`~transformers.TokenClassificationPipeline`.
- :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
- :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
- :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.

View File

@@ -17,7 +17,7 @@ import unittest
from .test_pipelines_common import MonoInputPipelineCommonMixin
class SentimentAnalysisPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
class TextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "sentiment-analysis"
small_models = [
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"

View File

@@ -27,7 +27,7 @@ if is_torch_available():
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "ner"
small_models = [
"sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"