mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Set inputs
as kwarg in TextClassificationPipeline
(#29495)
* Set `inputs` as kwarg in `TextClassificationPipeline` This change has been done to align the `TextClassificationPipeline` with the rest of the pipelines, and to be able to e.g. `pipeline(**{"inputs": "text"})` which wouldn't be possible since the `*args` were being used instead. * Add `noqa: C409` on `tuple([inputs],)` Even though is discouraged by the linter, the cast `tuple(list(...),)` is required here, as otherwise the original list in `inputs` will be transformed into a `tuple` and the elements 1...N will be ignored by the `Pipeline` * Run `ruff format` * Simplify `tuple` conversion with `(inputs,)` Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> --------- Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent
4ed9ae623d
commit
ddf177ee4a
@ -118,12 +118,12 @@ class TextClassificationPipeline(Pipeline):
|
||||
postprocess_params["function_to_apply"] = function_to_apply
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, inputs, **kwargs):
|
||||
"""
|
||||
Classify the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
|
||||
inputs (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
|
||||
One or several texts to classify. In order to use text pairs for your classification, you can send a
|
||||
dictionary containing `{"text", "text_pair"}` keys, or a list of those.
|
||||
top_k (`int`, *optional*, defaults to `1`):
|
||||
@ -152,10 +152,11 @@ class TextClassificationPipeline(Pipeline):
|
||||
|
||||
If `top_k` is used, one such dictionary is returned per label.
|
||||
"""
|
||||
result = super().__call__(*args, **kwargs)
|
||||
inputs = (inputs,)
|
||||
result = super().__call__(*inputs, **kwargs)
|
||||
# TODO try and retrieve it in a nicer way from _sanitize_parameters.
|
||||
_legacy = "top_k" not in kwargs
|
||||
if isinstance(args[0], str) and _legacy:
|
||||
if isinstance(inputs[0], str) and _legacy:
|
||||
# This pipeline is odd, and return a list when single item is run
|
||||
return [result]
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user