mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adding pipeline task aliases. (#11247)
* Adding task aliases and adding `token-classification` and `text-classification` tasks. * Cleaning docstring.
This commit is contained in:
parent
aaaed56ffc
commit
c3fcba3219
@ -14,7 +14,7 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
@ -63,7 +63,9 @@ class RunCommand(BaseTransformersCLICommand):
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
|
||||
run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
|
||||
run_parser.add_argument(
|
||||
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
|
||||
)
|
||||
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
|
||||
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
|
||||
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
|
||||
|
@ -15,7 +15,7 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
|
||||
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
@ -102,7 +102,10 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
"--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
|
||||
"--task",
|
||||
type=str,
|
||||
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
|
||||
help="The task to run the pipeline on",
|
||||
)
|
||||
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||
|
@ -93,6 +93,10 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Register all the supported tasks here
|
||||
TASK_ALIASES = {
|
||||
"sentiment-analysis": "text-classification",
|
||||
"ner": "token-classification",
|
||||
}
|
||||
SUPPORTED_TASKS = {
|
||||
"feature-extraction": {
|
||||
"impl": FeatureExtractionPipeline,
|
||||
@ -100,7 +104,7 @@ SUPPORTED_TASKS = {
|
||||
"pt": AutoModel if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
||||
},
|
||||
"sentiment-analysis": {
|
||||
"text-classification": {
|
||||
"impl": TextClassificationPipeline,
|
||||
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
|
||||
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
|
||||
@ -111,7 +115,7 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"ner": {
|
||||
"token-classification": {
|
||||
"impl": TokenClassificationPipeline,
|
||||
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
|
||||
"pt": AutoModelForTokenClassification if is_torch_available() else None,
|
||||
@ -206,8 +210,10 @@ def check_task(task: str) -> Tuple[Dict, Any]:
|
||||
The task defining which pipeline will be returned. Currently accepted tasks are:
|
||||
|
||||
- :obj:`"feature-extraction"`
|
||||
- :obj:`"sentiment-analysis"`
|
||||
- :obj:`"ner"`
|
||||
- :obj:`"text-classification"`
|
||||
- :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`)
|
||||
- :obj:`"token-classification"`
|
||||
- :obj:`"ner"` (alias of :obj:`"token-classification"`)
|
||||
- :obj:`"question-answering"`
|
||||
- :obj:`"fill-mask"`
|
||||
- :obj:`"summarization"`
|
||||
@ -222,6 +228,8 @@ def check_task(task: str) -> Tuple[Dict, Any]:
|
||||
|
||||
|
||||
"""
|
||||
if task in TASK_ALIASES:
|
||||
task = TASK_ALIASES[task]
|
||||
if task in SUPPORTED_TASKS:
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
return targeted_task, None
|
||||
@ -264,8 +272,12 @@ def pipeline(
|
||||
The task defining which pipeline will be returned. Currently accepted tasks are:
|
||||
|
||||
- :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
|
||||
- :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
|
||||
- :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
|
||||
- :obj:`"text-classification"`: will return a :class:`~transformers.TextClassificationPipeline`.
|
||||
- :obj:`"sentiment-analysis"`: (alias of :obj:`"text-classification"`) will return a
|
||||
:class:`~transformers.TextClassificationPipeline`.
|
||||
- :obj:`"token-classification"`: will return a :class:`~transformers.TokenClassificationPipeline`.
|
||||
- :obj:`"ner"` (alias of :obj:`"token-classification"`): will return a
|
||||
:class:`~transformers.TokenClassificationPipeline`.
|
||||
- :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
|
||||
- :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
|
||||
- :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
|
||||
|
@ -17,7 +17,7 @@ import unittest
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
|
||||
class SentimentAnalysisPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
class TextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "sentiment-analysis"
|
||||
small_models = [
|
||||
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
@ -27,7 +27,7 @@ if is_torch_available():
|
||||
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
||||
|
||||
|
||||
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "ner"
|
||||
small_models = [
|
||||
"sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
Loading…
Reference in New Issue
Block a user