mirror of https://github.com/huggingface/transformers.git
Adding QoL for `batch_size` arg (like others enabled everywhere). (#15027)

* Adding QoL for `batch_size` arg (like others enabled everywhere).
* Typo.
This commit is contained in:
parent e34dd055e9
commit 65cb94ff77
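In practical terms, this change lets `batch_size` and `num_workers` be set once when the pipeline is constructed and reused for every subsequent call, with explicit per-call values still taking precedence. A minimal usage sketch, assuming the PyTorch backend; the checkpoint name is the tiny test model used in this PR's own tests:

from transformers import pipeline

# Constructor-level default: batch_size applies to every call unless overridden.
# (num_workers can be set the same way.)
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", batch_size=2)

# Uses batch_size=2 from the constructor.
print(pipe(["I love this.", "I hate this."]))

# An explicit per-call value still takes precedence over the constructor default.
print(pipe(["I love this.", "I hate this."], batch_size=4))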
@@ -742,6 +742,8 @@ class Pipeline(_ScikitCompat):
             self.model.config.update(task_specific_params.get(task))
 
         self.call_count = 0
+        self._batch_size = kwargs.pop("batch_size", None)
+        self._num_workers = kwargs.pop("num_workers", None)
         self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
 
     def save_pretrained(self, save_directory: str):
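The constructor pops the two new keyword arguments before the remaining kwargs are handed to `_sanitize_parameters`, and stores `None` when they were not supplied so that `__call__` can tell "not set" apart from an explicit value. A tiny standalone sketch of that `kwargs.pop` behaviour (illustration only, not library code):

# Hypothetical kwargs as they might arrive in Pipeline.__init__.
kwargs = {"batch_size": 2, "num_workers": 1, "top_k": 5}

batch_size = kwargs.pop("batch_size", None)    # -> 2; would be None if absent
num_workers = kwargs.pop("num_workers", None)  # -> 1; would be None if absent

# Only the task-specific parameters are left for _sanitize_parameters.
print(batch_size, num_workers, kwargs)         # 2 1 {'top_k': 5}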
@@ -947,9 +949,21 @@ class Pipeline(_ScikitCompat):
         final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
         return final_iterator
 
-    def __call__(self, inputs, *args, num_workers=0, batch_size=1, **kwargs):
+    def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs):
         if args:
             logger.warning(f"Ignoring args : {args}")
+
+        if num_workers is None:
+            if self._num_workers is None:
+                num_workers = 0
+            else:
+                num_workers = self._num_workers
+        if batch_size is None:
+            if self._batch_size is None:
+                batch_size = 1
+            else:
+                batch_size = self._batch_size
+
         preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
 
         # Fuse __init__ params and __call__ params without modifying the __init__ ones.
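`__call__` now resolves each value in a fixed order: an explicit per-call argument wins, otherwise the constructor-level value stored in `__init__` is used, and only if both are `None` do the previous hard defaults (`num_workers=0`, `batch_size=1`) apply. The same precedence as a standalone sketch (hypothetical helper, not part of the library):

def resolve(call_value, init_value, hard_default):
    # Per-call argument > constructor default > built-in fallback.
    if call_value is not None:
        return call_value
    if init_value is not None:
        return init_value
    return hard_default

assert resolve(None, None, 1) == 1  # nothing set anywhere -> old default batch_size=1
assert resolve(None, 2, 1) == 2     # set once at construction time
assert resolve(4, 2, 1) == 4        # per-call override still wins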
@@ -299,6 +299,16 @@ class CommonPipelineTest(unittest.TestCase):
 
         self.assertIsInstance(pipe, TextClassificationPipeline)
 
+    @require_torch
+    def test_pipeline_batch_size_global(self):
+        pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert")
+        self.assertEqual(pipe._batch_size, None)
+        self.assertEqual(pipe._num_workers, None)
+
+        pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", batch_size=2, num_workers=1)
+        self.assertEqual(pipe._batch_size, 2)
+        self.assertEqual(pipe._num_workers, 1)
+
     @require_torch
     def test_pipeline_override(self):
         class MyPipeline(TextClassificationPipeline):
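The new test only asserts that the constructor-level values end up stored on the pipeline object. In practice, the stored default means a stream of inputs can be batched without repeating `batch_size` on every call; a hedged sketch using the same tiny test checkpoint (batching over an iterable requires the PyTorch backend):

from transformers import pipeline

pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", batch_size=8)

def texts():
    # Any iterable of strings works; a generator keeps memory flat.
    for i in range(32):
        yield f"short example number {i}"

# Inputs are grouped into batches of 8 internally; results stream back one per input.
for out in pipe(texts()):
    print(out)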