mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Pipeline should be agnostic (#12656)
This commit is contained in:
parent
9b3aab2cce
commit
fd41e2daf4
@ -14,6 +14,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.data.processors.squad import SquadExample
|
||||
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
|
||||
from transformers.testing_utils import slow
|
||||
@ -57,7 +58,7 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
task=self.pipeline_task,
|
||||
model=model,
|
||||
tokenizer=model,
|
||||
framework="pt",
|
||||
framework="pt" if is_torch_available() else "tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
for model in self.small_models
|
||||
@ -65,6 +66,7 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
return question_answering_pipelines
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(not is_torch_available() and not is_tf_available(), "Either torch or TF must be installed.")
|
||||
def test_high_topk_small_context(self):
|
||||
self.pipeline_running_kwargs.update({"topk": 20})
|
||||
valid_inputs = [
|
||||
|
Loading…
Reference in New Issue
Block a user