Fixing batching pipelines on single items for ChunkPipeline (#21132)

* Fixing #20783

* Update src/transformers/pipelines/base.py

* Fixing some tests.

* Fixup.

* Remove ffmpeg dep + a bit more relaxed for bigbird QA precision.

* Better dataset.

* Prevent failing on TF.

* Better condition. We can't use `can_use_iterator` since we cannot use it
directly.
This commit is contained in:
Nicolas Patry 2023-01-16 15:04:27 +01:00 committed by GitHub
parent fa906a264b
commit 488a179ce1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 1 deletions

View File

@ -1072,6 +1072,14 @@ class Pipeline(_ScikitCompat):
)
elif is_iterable:
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
elif self.framework == "pt" and isinstance(self, ChunkPipeline):
return next(
iter(
self.get_iterator(
[inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
)
)
)
else:
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

View File

@ -26,6 +26,7 @@ from functools import lru_cache
from pathlib import Path
from unittest import skipIf
import datasets
import numpy as np
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
@ -965,6 +966,29 @@ class CustomPipelineTest(unittest.TestCase):
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
@require_torch
def test_chunk_pipeline_batching_single_file(self):
# Make sure we have cached the pipeline.
pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]
pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
# For some reason scoping doesn't work if not using `self.`
self.COUNT = 0
forward = pipe.model.forward
def new_forward(*args, **kwargs):
self.COUNT += 1
return forward(*args, **kwargs)
pipe.model.forward = new_forward
for out in pipe(audio, return_timestamps="char", chunk_length_s=3, stride_length_s=[1, 1], batch_size=1024):
pass
self.assertEqual(self.COUNT, 1)
@require_torch
@is_staging_test

View File

@ -106,11 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
# Using batch is OK
if question_answerer.tokenizer.pad_token_id is None:
question_answerer.tokenizer.pad_token_id = question_answerer.model.config.eos_token_id
new_outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
)
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
self.assertEqual(outputs, new_outputs)
self.assertEqual(nested_simplify(outputs), nested_simplify(new_outputs))
@require_torch
def test_small_model_pt(self):