diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index 0c54fe1706c..6521da098d4 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -118,12 +118,12 @@ class TextClassificationPipeline(Pipeline): postprocess_params["function_to_apply"] = function_to_apply return preprocess_params, {}, postprocess_params - def __call__(self, *args, **kwargs): + def __call__(self, inputs, **kwargs): """ Classify the text(s) given as inputs. Args: - args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`): + inputs (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`): One or several texts to classify. In order to use text pairs for your classification, you can send a dictionary containing `{"text", "text_pair"}` keys, or a list of those. top_k (`int`, *optional*, defaults to `1`): @@ -152,10 +152,11 @@ class TextClassificationPipeline(Pipeline): If `top_k` is used, one such dictionary is returned per label. """ - result = super().__call__(*args, **kwargs) + inputs = (inputs,) + result = super().__call__(*inputs, **kwargs) # TODO try and retrieve it in a nicer way from _sanitize_parameters. _legacy = "top_k" not in kwargs - if isinstance(args[0], str) and _legacy: + if isinstance(inputs[0], str) and _legacy: # This pipeline is odd, and return a list when single item is run return [result] else: