Set inputs as kwarg in TextClassificationPipeline (#29495)

* Set `inputs` as kwarg in `TextClassificationPipeline`

This change aligns the `TextClassificationPipeline` with the rest of the pipelines and makes calls such as `pipeline(**{"inputs": "text"})` possible, which did not work before because `*args` was used instead.

* Add `noqa: C409` on `tuple([inputs],)`

Even though it is discouraged by the linter, the cast `tuple(list(...),)` is required here; otherwise the original list in `inputs` would be converted into a `tuple`, and elements 1...N would be ignored by the `Pipeline`.

* Run `ruff format`

* Simplify `tuple` conversion with `(inputs,)`
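
A minimal sketch of why the wrapping matters, using toy stand-ins rather than the real `Pipeline` classes. The names `receive` and `call` are hypothetical and exist only for illustration; the assumption is that the callee reads just its first positional argument.

```python
def receive(inputs, *extra):
    """Hypothetical stand-in for a callee that only reads its first positional argument."""
    return inputs, extra


def call(inputs, **kwargs):
    """Toy version of the new signature: `inputs` is a named parameter."""
    wrapped = (inputs,)  # one-element tuple; the list inside stays intact
    return receive(*wrapped)


texts = ["first", "second", "third"]

# tuple(texts) spreads the list, so unpacking passes three positional arguments
# and only "first" lands in `inputs`.
spread = receive(*tuple(texts))

# (texts,) and tuple([texts]) both build a one-element tuple holding the whole list.
kept = call(texts)
same = tuple([texts]) == (texts,)

# Because `inputs` is a named parameter, it can also be supplied through a dict.
by_name = call(**{"inputs": "text"})
```

Here `spread` carries only the first element in `inputs`, while `kept` carries the whole list, which is the behaviour `(inputs,)` is meant to preserve.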

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

Alvaro Bartolome 2024-03-07 21:43:57 +01:00 committed by GitHub
parent 4ed9ae623d
commit ddf177ee4a
@@ -118,12 +118,12 @@ class TextClassificationPipeline(Pipeline):
postprocess_params["function_to_apply"] = function_to_apply
return preprocess_params, {}, postprocess_params
-def __call__(self, *args, **kwargs):
+def __call__(self, inputs, **kwargs):
"""
Classify the text(s) given as inputs.
Args:
-args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
+inputs (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
One or several texts to classify. In order to use text pairs for your classification, you can send a
dictionary containing `{"text", "text_pair"}` keys, or a list of those.
top_k (`int`, *optional*, defaults to `1`):
@@ -152,10 +152,11 @@ class TextClassificationPipeline(Pipeline):
If `top_k` is used, one such dictionary is returned per label.
"""
-result = super().__call__(*args, **kwargs)
+inputs = (inputs,)
+result = super().__call__(*inputs, **kwargs)
# TODO try and retrieve it in a nicer way from _sanitize_parameters.
_legacy = "top_k" not in kwargs
-if isinstance(args[0], str) and _legacy:
+if isinstance(inputs[0], str) and _legacy:
# This pipeline is odd, and return a list when single item is run
return [result]
else: