mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fixing batching pipelines on single items for ChunkPipeline (#21132)
* Fixing #20783
* Update src/transformers/pipelines/base.py
* Fixing some tests.
* Fixup.
* Remove ffmpeg dep + a bit more relaxed for bigbird QA precision.
* Better dataset.
* Prevent failing on TF.
* Better condition. We can't rely on `can_use_iterator` here, since we cannot use it directly.
This commit is contained in:
parent
fa906a264b
commit
488a179ce1
@ -1072,6 +1072,14 @@ class Pipeline(_ScikitCompat):
|
||||
)
|
||||
elif is_iterable:
|
||||
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
|
||||
elif self.framework == "pt" and isinstance(self, ChunkPipeline):
|
||||
return next(
|
||||
iter(
|
||||
self.get_iterator(
|
||||
[inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
||||
|
||||
|
@ -26,6 +26,7 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from unittest import skipIf
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
@ -965,6 +966,29 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@require_torch
def test_chunk_pipeline_batching_single_file(self):
    """A single item fed to a ChunkPipeline with a large batch_size should trigger exactly one forward pass."""
    # Warm the cache so the second `pipeline(...)` call below is fast.
    pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
    ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
    audio = ds[40]["audio"]["array"]

    pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
    # Counter lives on `self` — closure scoping over a plain local does not
    # survive here for some reason.
    self.COUNT = 0
    original_forward = pipe.model.forward

    def counting_forward(*args, **kwargs):
        # Tally each real forward pass, then delegate to the model.
        self.COUNT += 1
        return original_forward(*args, **kwargs)

    pipe.model.forward = counting_forward

    # Drain the pipeline output; only the call count matters.
    for _ in pipe(audio, return_timestamps="char", chunk_length_s=3, stride_length_s=[1, 1], batch_size=1024):
        pass

    # All chunks of the single file must have been batched into one forward call.
    self.assertEqual(self.COUNT, 1)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
@ -106,11 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||
|
||||
# Using batch is OK
|
||||
if question_answerer.tokenizer.pad_token_id is None:
|
||||
question_answerer.tokenizer.pad_token_id = question_answerer.model.config.eos_token_id
|
||||
new_outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
|
||||
)
|
||||
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||
self.assertEqual(outputs, new_outputs)
|
||||
self.assertEqual(nested_simplify(outputs), nested_simplify(new_outputs))
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
|
Loading…
Reference in New Issue
Block a user