diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 971c0211b57..d3ee4e871e2 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -384,7 +384,7 @@ def get_framework(model, revision: Optional[str] = None): def get_default_model_and_revision( targeted_task: Dict, framework: Optional[str], task_options: Optional[Any] -) -> Union[str, Tuple[str, str]]: +) -> Tuple[str, str]: """ Select a default model to use for a given task. Defaults to pytorch if ambiguous. @@ -401,7 +401,9 @@ def get_default_model_and_revision( Returns - `str` The model string representing the default model for this pipeline + Tuple: + - `str` The model string representing the default model for this pipeline. + - `str` The revision of the model. """ if is_torch_available() and not is_tf_available(): framework = "pt" diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 1fec4be3d95..baee8999284 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -796,7 +796,7 @@ class CustomPipelineTest(unittest.TestCase): pipeline_class=PairClassificationPipeline, pt_model=AutoModelForSequenceClassification if is_torch_available() else None, tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None, - default={"pt": "hf-internal-testing/tiny-random-distilbert"}, + default={"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}, type="text", ) assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks() @@ -806,7 +806,9 @@ class CustomPipelineTest(unittest.TestCase): self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ()) self.assertEqual(task_def["type"], "text") self.assertEqual(task_def["impl"], PairClassificationPipeline) - self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}}) + self.assertEqual( + task_def["default"], {"model": {"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}} + ) # Clean registry for next tests. del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]