mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
ChunkPipeline (batch_size enabled on zero-cls
and qa
pipelines. (#14225)
* Pipeline chunks. * Batching for Chunking pipelines ? * Batching for `question-answering` and `zero-shot-cls`. * Fixing for FNet. * Making ASR a chunk pipeline. * Chunking ASR API. * doc style. * Fixing ASR test. * Fixing QA eror (p_mask, padding is 1, not 0). * Enable both vad and simple chunking. * Max length for vad. * remove inference mode, crashing on s2t. * Revert ChunkPipeline for ASRpipeline. Too many knobs for simple integration within the pipeline, better stick to external convenience functions instead, more control to be had, simpler pipeline and also easier to replace with other things later. * Drop necessity for PT for these. * Enabling generators. * Add mic + cleanup. * Typo. * Typo2. * Remove ASR work, it does not belong in this PR anymore. * Update src/transformers/pipelines/pt_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/pipelines/zero_shot_classification.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Adding many comments. * Doc quality. * `hidden_states` handling. * Adding doc. * Bad rebase. * Autofixing docs. * Fixing CRITICAL bug in the new Zerocls pipeline. Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
705ca7f21b
commit
b058490ceb
@ -93,12 +93,36 @@ for out in tqdm.tqdm(pipe(KeyDataset(dataset, "file"))):
|
||||
# ....
|
||||
```
|
||||
|
||||
For ease of use, a generator is also possible:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline("text-classification")
|
||||
|
||||
def data():
|
||||
while True:
|
||||
# This could come from a dataset, a database, a queue or HTTP request
|
||||
# in a server
|
||||
# Caveat: because this is iterative, you cannot use `num_workers > 1` variable
|
||||
# to use multiple threads to preprocess data. You can still have 1 thread that
|
||||
# does the preprocessing while the main runs the big inference
|
||||
yield "This is a test"
|
||||
|
||||
for out in pipe(data()):
|
||||
print(out)
|
||||
# {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"}
|
||||
# {"text": ....}
|
||||
# ....
|
||||
```
|
||||
|
||||
[[autodoc]] pipeline
|
||||
|
||||
## Pipeline batching
|
||||
|
||||
All pipelines (except *zero-shot-classification* and *question-answering* currently) can use batching. This will work
|
||||
whenever the pipeline uses its streaming ability (so when passing lists or `Dataset`).
|
||||
All pipelines can use batching. This will work
|
||||
whenever the pipeline uses its streaming ability (so when passing lists or `Dataset` or `generator`).
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
@ -120,7 +144,7 @@ for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_firs
|
||||
However, this is not automatically a win for performance. It can be either a 10x speedup or 5x slowdown depending
|
||||
on hardware, data and the actual model being used.
|
||||
|
||||
Example where it's most a speedup:
|
||||
Example where it's mostly a speedup:
|
||||
|
||||
</Tip>
|
||||
|
||||
@ -227,6 +251,39 @@ For users, a rule of thumb is:
|
||||
- The larger the GPU the more likely batching is going to be more interesting
|
||||
- As soon as you enable batching, make sure you can handle OOMs nicely.
|
||||
|
||||
## Pipeline chunk batching
|
||||
|
||||
`zero-shot-classification` and `question-answering` are slightly specific in the sense, that a single input might yield
|
||||
mutliple forward pass of a model. Under normal circumstances, this would yield issues with `batch_size` argument.
|
||||
|
||||
In order to circumvent this issue, both of these pipelines are a bit specific, they are `ChunkPipeline` instead of
|
||||
regular `Pipeline`. In short:
|
||||
|
||||
|
||||
```python
|
||||
preprocessed = pipe.preprocess(inputs)
|
||||
model_outputs = pipe.forward(preprocessed)
|
||||
outputs = pipe.postprocess(model_ouputs)
|
||||
```
|
||||
|
||||
Now becomes:
|
||||
|
||||
|
||||
```python
|
||||
all_model_outputs = []
|
||||
for preprocessed in pipe.preprocess(inputs):
|
||||
model_outputs = pipe.forward(preprocessed)
|
||||
all_model_outputs.append(model_outputs)
|
||||
outputs = pipe.postprocess(all_model_ouputs)
|
||||
```
|
||||
|
||||
This should be very transparent to your code because the pipelines are used in
|
||||
the same way.
|
||||
|
||||
This is a simplified view, since the pipeline can handle automatically the batch to ! Meaning you don't have to care
|
||||
about how many forward passes you inputs are actually going to trigger, you can optimize the `batch_size`
|
||||
independantly of the inputs. The caveats from the previous section still apply.
|
||||
|
||||
## Pipeline custom code
|
||||
|
||||
If you want to override a specific pipeline.
|
||||
|
@ -27,7 +27,6 @@ from contextlib import contextmanager
|
||||
from os.path import abspath, exists
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
@ -47,7 +46,7 @@ if is_tf_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from ..models.auto.modeling_auto import AutoModel
|
||||
else:
|
||||
@ -128,9 +127,15 @@ def pad_collate_fn(tokenizer, feature_extractor):
|
||||
f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} != {keys})"
|
||||
)
|
||||
# input_values, input_pixels, input_ids, ...
|
||||
padded = {
|
||||
key: _pad(items, key, padding_value if key.startswith("input_") else 0, padding_side) for key in keys
|
||||
}
|
||||
padded = {}
|
||||
for key in keys:
|
||||
if key.startswith("input_"):
|
||||
_padding_value = padding_value
|
||||
elif key == "p_mask":
|
||||
_padding_value = 1
|
||||
else:
|
||||
_padding_value = 0
|
||||
padded[key] = _pad(items, key, _padding_value, padding_side)
|
||||
return padded
|
||||
|
||||
return inner
|
||||
@ -676,127 +681,12 @@ PIPELINE_INIT_ARGS = r"""
|
||||
"""
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
class PipelineDataset(Dataset):
|
||||
def __init__(self, dataset, process, params):
|
||||
self.dataset = dataset
|
||||
self.process = process
|
||||
self.params = params
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, i):
|
||||
item = self.dataset[i]
|
||||
processed = self.process(item, **self.params)
|
||||
return processed
|
||||
|
||||
class PipelineIterator(IterableDataset):
|
||||
def __init__(self, loader, infer, params, loader_batch_size=None):
|
||||
"""
|
||||
Roughly equivalent to
|
||||
|
||||
```python
|
||||
for item in loader:
|
||||
yield infer(item, **params)
|
||||
```
|
||||
|
||||
Arguments:
|
||||
loader (`torch.utils.data.DataLoader` or any iterator):
|
||||
The iterator that will be used to apply `infer` on.
|
||||
infer (any function):
|
||||
The function to apply of each element of `loader`.
|
||||
params (`dict`):
|
||||
The parameters passed to `infer` along with every item
|
||||
loader_batch_size (`int`, *optional*):
|
||||
If specified, the items of `loader` are supposed to come as batch, and are loader_batched here
|
||||
making it roughly behave as
|
||||
|
||||
|
||||
```python
|
||||
for items in loader:
|
||||
for i in loader_batch_size:
|
||||
item = items[i]
|
||||
yield infer(item, **params)
|
||||
```"""
|
||||
self.loader = loader
|
||||
self.infer = infer
|
||||
self.params = params
|
||||
if loader_batch_size == 1:
|
||||
# Let's spare some time by deactivating altogether
|
||||
loader_batch_size = None
|
||||
self.loader_batch_size = loader_batch_size
|
||||
|
||||
# Internal bookkeeping
|
||||
self._loader_batch_index = None
|
||||
self._loader_batch_data = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
||||
|
||||
def __iter__(self):
|
||||
self.iterator = iter(self.loader)
|
||||
return self
|
||||
|
||||
def loader_batch_item(self):
|
||||
if isinstance(self._loader_batch_data, torch.Tensor):
|
||||
result = self._loader_batch_data[self._loader_batch_index]
|
||||
else:
|
||||
loader_batched = {}
|
||||
for k, element in self._loader_batch_data.items():
|
||||
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
|
||||
if isinstance(element[0], torch.Tensor):
|
||||
loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
|
||||
elif isinstance(element[0], np.ndarray):
|
||||
loader_batched[k] = tuple(
|
||||
np.expand_dims(el[self._loader_batch_index], 0) for el in element
|
||||
)
|
||||
elif isinstance(element[self._loader_batch_index], torch.Tensor):
|
||||
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
|
||||
elif isinstance(element[self._loader_batch_index], np.ndarray):
|
||||
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
|
||||
else:
|
||||
loader_batched[k] = element[self._loader_batch_index]
|
||||
result = self._loader_batch_data.__class__(loader_batched)
|
||||
self._loader_batch_index += 1
|
||||
return result
|
||||
|
||||
def __next__(self):
|
||||
if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
|
||||
return self.loader_batch_item()
|
||||
|
||||
item = next(self.iterator)
|
||||
processed = self.infer(item, **self.params)
|
||||
if self.loader_batch_size is not None:
|
||||
if isinstance(processed, torch.Tensor):
|
||||
first_tensor = processed
|
||||
else:
|
||||
key = list(processed.keys())[0]
|
||||
first_tensor = processed[key]
|
||||
if isinstance(first_tensor, list):
|
||||
observed_batch_size = len(first_tensor)
|
||||
else:
|
||||
observed_batch_size = first_tensor.shape[0]
|
||||
if 0 < observed_batch_size < self.loader_batch_size:
|
||||
# Could be last batch so we can't unroll as many
|
||||
# elements.
|
||||
self.loader_batch_size = observed_batch_size
|
||||
self._loader_batch_data = processed
|
||||
self._loader_batch_index = 0
|
||||
return self.loader_batch_item()
|
||||
else:
|
||||
return processed
|
||||
|
||||
class KeyDataset(Dataset):
|
||||
def __init__(self, dataset: Dataset, key: str):
|
||||
self.dataset = dataset
|
||||
self.key = key
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.dataset[i][self.key]
|
||||
from transformers.pipelines.pt_utils import (
|
||||
PipelineChunkIterator,
|
||||
PipelineDataset,
|
||||
PipelineIterator,
|
||||
PipelinePackIterator,
|
||||
)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
@ -1076,8 +966,18 @@ class Pipeline(_ScikitCompat):
|
||||
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
|
||||
UserWarning,
|
||||
)
|
||||
if isinstance(inputs, list):
|
||||
if self.framework == "pt":
|
||||
|
||||
is_dataset = Dataset is not None and isinstance(inputs, Dataset)
|
||||
is_generator = isinstance(inputs, types.GeneratorType)
|
||||
is_list = isinstance(inputs, list)
|
||||
|
||||
is_iterable = is_dataset or is_generator or is_list
|
||||
|
||||
# TODO make the get_iterator work also for `tf` (and `flax`).
|
||||
can_use_iterator = self.framework == "pt" and (is_dataset or is_generator or is_list)
|
||||
|
||||
if is_list:
|
||||
if can_use_iterator:
|
||||
final_iterator = self.get_iterator(
|
||||
inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
||||
)
|
||||
@ -1085,18 +985,12 @@ class Pipeline(_ScikitCompat):
|
||||
return outputs
|
||||
else:
|
||||
return self.run_multi(inputs, preprocess_params, forward_params, postprocess_params)
|
||||
elif Dataset is not None and isinstance(inputs, Dataset):
|
||||
elif can_use_iterator:
|
||||
return self.get_iterator(
|
||||
inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
||||
)
|
||||
elif isinstance(inputs, types.GeneratorType):
|
||||
if self.framework == "pt":
|
||||
return self.get_iterator(
|
||||
inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
||||
)
|
||||
else:
|
||||
# TODO make the get_iterator work also for `tf` (and `flax`).
|
||||
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
|
||||
elif is_iterable:
|
||||
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
|
||||
else:
|
||||
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
||||
|
||||
@ -1114,3 +1008,31 @@ class Pipeline(_ScikitCompat):
|
||||
# easy solution.
|
||||
for input_ in inputs:
|
||||
yield self.run_single(input_, preprocess_params, forward_params, postprocess_params)
|
||||
|
||||
|
||||
class ChunkPipeline(Pipeline):
|
||||
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
|
||||
all_outputs = []
|
||||
for model_inputs in self.preprocess(inputs, **preprocess_params):
|
||||
model_outputs = self.forward(model_inputs, **forward_params)
|
||||
all_outputs.append(model_outputs)
|
||||
outputs = self.postprocess(all_outputs, **postprocess_params)
|
||||
return outputs
|
||||
|
||||
def get_iterator(
|
||||
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
||||
):
|
||||
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
||||
logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if num_workers > 1:
|
||||
logger.warning(
|
||||
"For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable, setting `num_workers=1` to guarantee correctness."
|
||||
)
|
||||
num_workers = 1
|
||||
dataset = PipelineChunkIterator(inputs, self.preprocess, preprocess_params)
|
||||
collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, self.feature_extractor)
|
||||
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
|
||||
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
||||
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
||||
return final_iterator
|
||||
|
292
src/transformers/pipelines/pt_utils.py
Normal file
292
src/transformers/pipelines/pt_utils.py
Normal file
@ -0,0 +1,292 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, IterableDataset
|
||||
|
||||
|
||||
class PipelineDataset(Dataset):
|
||||
def __init__(self, dataset, process, params):
|
||||
self.dataset = dataset
|
||||
self.process = process
|
||||
self.params = params
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, i):
|
||||
item = self.dataset[i]
|
||||
processed = self.process(item, **self.params)
|
||||
return processed
|
||||
|
||||
|
||||
class PipelineIterator(IterableDataset):
|
||||
def __init__(self, loader, infer, params, loader_batch_size=None):
|
||||
"""
|
||||
Roughly equivalent to
|
||||
|
||||
```
|
||||
for item in loader:
|
||||
yield infer(item, **params)
|
||||
```
|
||||
|
||||
Arguments:
|
||||
loader (`torch.utils.data.DataLoader` or any iterator):
|
||||
The iterator that will be used to apply `infer` on.
|
||||
infer (any function):
|
||||
The function to apply of each element of `loader`.
|
||||
params (`dict`):
|
||||
The parameters passed to `infer` along with every item
|
||||
loader_batch_size (`int`, *optional*):
|
||||
If specified, the items of `loader` are supposed to come as batch, and are loader_batched here
|
||||
making it roughly behave as
|
||||
|
||||
|
||||
```
|
||||
for items in loader:
|
||||
for i in loader_batch_size:
|
||||
item = items[i]
|
||||
yield infer(item, **params)
|
||||
```"""
|
||||
self.loader = loader
|
||||
self.infer = infer
|
||||
self.params = params
|
||||
if loader_batch_size == 1:
|
||||
# Let's spare some time by deactivating altogether
|
||||
loader_batch_size = None
|
||||
self.loader_batch_size = loader_batch_size
|
||||
|
||||
# Internal bookkeeping
|
||||
self._loader_batch_index = None
|
||||
self._loader_batch_data = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
||||
|
||||
def __iter__(self):
|
||||
self.iterator = iter(self.loader)
|
||||
return self
|
||||
|
||||
def loader_batch_item(self):
|
||||
"""
|
||||
Return item located at `loader_batch_index` within the current `loader_batch_data`.
|
||||
"""
|
||||
if isinstance(self._loader_batch_data, torch.Tensor):
|
||||
# Batch data is simple tensor, just fetch the slice
|
||||
result = self._loader_batch_data[self._loader_batch_index]
|
||||
else:
|
||||
# Batch data is assumed to be BaseModelOutput (or dict)
|
||||
loader_batched = {}
|
||||
for k, element in self._loader_batch_data.items():
|
||||
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
|
||||
# Those are stored as lists of tensors so need specific unbatching.
|
||||
if isinstance(element[0], torch.Tensor):
|
||||
loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
|
||||
elif isinstance(element[0], np.ndarray):
|
||||
loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
|
||||
continue
|
||||
if isinstance(element[self._loader_batch_index], torch.Tensor):
|
||||
# Take correct batch data, but make it looked like batch_size=1
|
||||
# For compatibility with other methods within transformers
|
||||
|
||||
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
|
||||
elif isinstance(element[self._loader_batch_index], np.ndarray):
|
||||
# Take correct batch data, but make it looked like batch_size=1
|
||||
# For compatibility with other methods within transformers
|
||||
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
|
||||
else:
|
||||
# This is typically a list, so no need to `unsqueeze`.
|
||||
loader_batched[k] = element[self._loader_batch_index]
|
||||
# Recreate the element by reusing the original class to make it look
|
||||
# batch_size=1
|
||||
result = self._loader_batch_data.__class__(loader_batched)
|
||||
self._loader_batch_index += 1
|
||||
return result
|
||||
|
||||
def __next__(self):
|
||||
if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
|
||||
# We are currently unrolling a batch so we just need to return
|
||||
# the current item within a batch
|
||||
return self.loader_batch_item()
|
||||
|
||||
# We're out of items within a batch
|
||||
item = next(self.iterator)
|
||||
processed = self.infer(item, **self.params)
|
||||
# We now have a batch of "inferred things".
|
||||
if self.loader_batch_size is not None:
|
||||
# Try to infer the size of the batch
|
||||
if isinstance(processed, torch.Tensor):
|
||||
first_tensor = processed
|
||||
else:
|
||||
key = list(processed.keys())[0]
|
||||
first_tensor = processed[key]
|
||||
if isinstance(first_tensor, list):
|
||||
observed_batch_size = len(first_tensor)
|
||||
else:
|
||||
observed_batch_size = first_tensor.shape[0]
|
||||
if 0 < observed_batch_size < self.loader_batch_size:
|
||||
# could be last batch so we can't unroll as many
|
||||
# elements.
|
||||
self.loader_batch_size = observed_batch_size
|
||||
# Setting internal index to unwrap the batch
|
||||
self._loader_batch_data = processed
|
||||
self._loader_batch_index = 0
|
||||
return self.loader_batch_item()
|
||||
else:
|
||||
# We're not unrolling batches
|
||||
return processed
|
||||
|
||||
|
||||
class PipelineChunkIterator(PipelineIterator):
|
||||
def __init__(self, loader, infer, params, loader_batch_size=None):
|
||||
"""
|
||||
Roughly equivalent to
|
||||
|
||||
```
|
||||
for iterator in loader:
|
||||
for item in iterator:
|
||||
yield infer(item, **params)
|
||||
```
|
||||
|
||||
Arguments:
|
||||
loader (`torch.utils.data.DataLoader` or any iterator):
|
||||
The iterator that will be used to apply `infer` on.
|
||||
infer (any function):
|
||||
The function to apply of each element of `loader`.
|
||||
params (`dict`):
|
||||
The parameters passed to `infer` along with every item
|
||||
"""
|
||||
super().__init__(loader, infer, params)
|
||||
|
||||
def __iter__(self):
|
||||
self.iterator = iter(self.loader)
|
||||
self.subiterator = None
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.subiterator is None:
|
||||
"Subiterator None means we haven't started a `preprocess` iterator. so start it"
|
||||
self.subiterator = self.infer(next(self.iterator), **self.params)
|
||||
try:
|
||||
# Try to return next item
|
||||
processed = next(self.subiterator)
|
||||
except StopIteration:
|
||||
# When a preprocess iterator ends, we can start lookig at the next item
|
||||
# ChunkIterator will keep feeding until ALL elements of iterator
|
||||
# all have created their subiterator and have been iterating against.
|
||||
#
|
||||
# Another way to look at it, is we're basically flattening lists of lists
|
||||
# into a single list, but with generators
|
||||
self.subiterator = self.infer(next(self.iterator), **self.params)
|
||||
processed = next(self.subiterator)
|
||||
return processed
|
||||
|
||||
|
||||
class PipelinePackIterator(PipelineIterator):
|
||||
"""
|
||||
Roughly equivalent to
|
||||
|
||||
```
|
||||
packed = []
|
||||
for item in loader:
|
||||
packed.append(item)
|
||||
if item["is_last"]:
|
||||
yield packed
|
||||
packed = []
|
||||
```
|
||||
|
||||
but it also handles cases where `item` are batched (meaning it's a dict of Tensor with first dimension > 1. In
|
||||
that case it does
|
||||
|
||||
```
|
||||
packed = []
|
||||
for batch in loader:
|
||||
# item is batched
|
||||
for item in batch:
|
||||
packed.append(item)
|
||||
if item["is_last"]:
|
||||
yield packed
|
||||
packed = []
|
||||
```
|
||||
|
||||
Arguments:
|
||||
loader (`torch.utils.data.DataLoader` or any iterator):
|
||||
The iterator that will be used to apply `infer` on.
|
||||
infer (any function):
|
||||
The function to apply of each element of `loader`.
|
||||
params (`dict`):
|
||||
The parameters passed to `infer` along with every item
|
||||
loader_batch_size (`int`, *optional*):
|
||||
If specified, the items of `loader` are supposed to come as batch, and are loader_batched here making
|
||||
it roughly behave as
|
||||
|
||||
|
||||
```
|
||||
for items in loader:
|
||||
for i in loader_batch_size:
|
||||
item = items[i]
|
||||
yield infer(item, **params)
|
||||
```"""
|
||||
|
||||
def __iter__(self):
|
||||
self.iterator = iter(self.loader)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
# Extremely similar to PipelineIterator in its unpacking mechanism
|
||||
# BUT, we have an extra required item which is the presence of `is_last`
|
||||
# That is because everything is flattened by `PipelineChunkIterator` we
|
||||
# need to keep track of how to regroup here in the original `process`
|
||||
# boundaries so that `process` and `postprocess` see the same data.
|
||||
|
||||
# This iterator accumulates items (possibly while unbatching) until it
|
||||
# its a `is_last` and then just passes it on to the caller.
|
||||
is_last = False
|
||||
accumulator = []
|
||||
if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
|
||||
while self._loader_batch_index < self.loader_batch_size:
|
||||
item = self.loader_batch_item()
|
||||
is_last = item.pop("is_last")
|
||||
accumulator.append(item)
|
||||
if is_last:
|
||||
return accumulator
|
||||
|
||||
while not is_last:
|
||||
processed = self.infer(next(self.iterator), **self.params)
|
||||
if self.loader_batch_size is not None:
|
||||
if isinstance(processed, torch.Tensor):
|
||||
first_tensor = processed
|
||||
else:
|
||||
key = list(processed.keys())[0]
|
||||
first_tensor = processed[key]
|
||||
if isinstance(first_tensor, list):
|
||||
observed_batch_size = len(first_tensor)
|
||||
else:
|
||||
observed_batch_size = first_tensor.shape[0]
|
||||
if 0 < observed_batch_size < self.loader_batch_size:
|
||||
# could be last batch so we can't unroll as many
|
||||
# elements.
|
||||
self.loader_batch_size = observed_batch_size
|
||||
self._loader_batch_data = processed
|
||||
self._loader_batch_index = 0
|
||||
while self._loader_batch_index < self.loader_batch_size:
|
||||
item = self.loader_batch_item()
|
||||
is_last = item.pop("is_last")
|
||||
accumulator.append(item)
|
||||
if is_last:
|
||||
return accumulator
|
||||
else:
|
||||
item = processed
|
||||
is_last = item.pop("is_last")
|
||||
accumulator.append(item)
|
||||
return accumulator
|
||||
|
||||
|
||||
class KeyDataset(Dataset):
|
||||
def __init__(self, dataset: Dataset, key: str):
|
||||
self.dataset = dataset
|
||||
self.key = key
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.dataset[i][self.key]
|
@ -9,7 +9,7 @@ from ..file_utils import PaddingStrategy, add_end_docstrings, is_tf_available, i
|
||||
from ..modelcard import ModelCard
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
|
||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -99,7 +99,7 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class QuestionAnsweringPipeline(Pipeline):
|
||||
class QuestionAnsweringPipeline(ChunkPipeline):
|
||||
"""
|
||||
Question Answering pipeline using any `ModelForQuestionAnswering`. See the [question answering examples](../task_summary#question-answering) for more information.
|
||||
|
||||
@ -242,9 +242,6 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
- **end** (`int`) -- The character end index of the answer (in the tokenized version of the input).
|
||||
- **answer** (`str`) -- The answer to the question.
|
||||
"""
|
||||
if kwargs.get("batch_size", 1) > 1:
|
||||
logger.error("Batch_size > 1 is not supported for question answering pipeline, setting it to 1.")
|
||||
kwargs["batch_size"] = 1
|
||||
|
||||
# Convert inputs to features
|
||||
examples = self._args_parser(*args, **kwargs)
|
||||
@ -343,11 +340,10 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
)
|
||||
)
|
||||
|
||||
split_features = []
|
||||
for feature in features:
|
||||
for i, feature in enumerate(features):
|
||||
fw_args = {}
|
||||
others = {}
|
||||
model_input_names = self.tokenizer.model_input_names
|
||||
model_input_names = self.tokenizer.model_input_names + ["p_mask"]
|
||||
|
||||
for k, v in feature.__dict__.items():
|
||||
if k in model_input_names:
|
||||
@ -363,20 +359,15 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
fw_args[k] = tensor.unsqueeze(0)
|
||||
else:
|
||||
others[k] = v
|
||||
split_features.append({"fw_args": fw_args, "others": others})
|
||||
return {"features": split_features, "example": example}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
features = model_inputs["features"]
|
||||
example = model_inputs["example"]
|
||||
starts = []
|
||||
ends = []
|
||||
for feature in features:
|
||||
fw_args = feature["fw_args"]
|
||||
start, end = self.model(**fw_args)[:2]
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
return {"starts": starts, "ends": ends, "features": features, "example": example}
|
||||
is_last = i == len(features) - 1
|
||||
yield {"example": example, "is_last": is_last, **fw_args, **others}
|
||||
|
||||
def _forward(self, inputs):
|
||||
example = inputs["example"]
|
||||
model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
|
||||
start, end = self.model(**model_inputs)[:2]
|
||||
return {"start": start, "end": end, "example": example, **inputs}
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
@ -387,16 +378,16 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
):
|
||||
min_null_score = 1000000 # large and positive
|
||||
answers = []
|
||||
example = model_outputs["example"]
|
||||
for i, (feature_, start_, end_) in enumerate(
|
||||
zip(model_outputs["features"], model_outputs["starts"], model_outputs["ends"])
|
||||
):
|
||||
feature = feature_["others"]
|
||||
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
|
||||
undesired_tokens = np.abs(np.array(feature["p_mask"]) - 1)
|
||||
for output in model_outputs:
|
||||
start_ = output["start"]
|
||||
end_ = output["end"]
|
||||
example = output["example"]
|
||||
|
||||
if feature_["fw_args"].get("attention_mask", None) is not None:
|
||||
undesired_tokens = undesired_tokens & feature_["fw_args"]["attention_mask"].numpy()
|
||||
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
|
||||
undesired_tokens = np.abs(np.array(output["p_mask"]) - 1)
|
||||
|
||||
if output.get("attention_mask", None) is not None:
|
||||
undesired_tokens = undesired_tokens & output["attention_mask"].numpy()
|
||||
|
||||
# Generate mask
|
||||
undesired_tokens_mask = undesired_tokens == 0.0
|
||||
@ -425,7 +416,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
# End: Index of the character following the last character of the answer in the context string
|
||||
# Answer: Plain text of the answer
|
||||
for s, e, score in zip(starts, ends, scores):
|
||||
token_to_orig_map = feature["token_to_orig_map"]
|
||||
token_to_orig_map = output["token_to_orig_map"]
|
||||
answers.append(
|
||||
{
|
||||
"score": score.item(),
|
||||
@ -441,7 +432,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
# End: Index of the character following the last character of the answer in the context string
|
||||
# Answer: Plain text of the answer
|
||||
question_first = bool(self.tokenizer.padding_side == "right")
|
||||
enc = feature["encoding"]
|
||||
enc = output["encoding"]
|
||||
|
||||
# Sometimes the max probability token is in the middle of a word so:
|
||||
# - we start by finding the right word containing the token with `token_to_word`
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
from ..file_utils import add_end_docstrings
|
||||
from ..tokenization_utils import TruncationStrategy
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
|
||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -44,7 +44,7 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class ZeroShotClassificationPipeline(Pipeline):
|
||||
class ZeroShotClassificationPipeline(ChunkPipeline):
|
||||
"""
|
||||
NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural
|
||||
language inference) tasks.
|
||||
@ -84,48 +84,37 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
|
||||
"""
|
||||
return_tensors = self.framework
|
||||
if getattr(self.tokenizer, "pad_token", None) is None:
|
||||
# XXX some tokenizers do not have a padding token, we use simple lists
|
||||
# and no padding then
|
||||
logger.warning("The tokenizer {self.tokenizer} does not have a pad token, we're not running it as a batch")
|
||||
padding = False
|
||||
inputs = []
|
||||
for sequence_pair in sequence_pairs:
|
||||
model_input = self.tokenizer(
|
||||
text=sequence_pair[0],
|
||||
text_pair=sequence_pair[1],
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
)
|
||||
inputs.append(model_input)
|
||||
else:
|
||||
try:
|
||||
if self.tokenizer.pad_token is None:
|
||||
# Override for tokenizers not supporting padding
|
||||
logger.error(
|
||||
"Tokenizer was not supporting padding necessary for zero-shot, attempting to use `pad_token=eos_token`"
|
||||
)
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
try:
|
||||
inputs = self.tokenizer(
|
||||
sequence_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
)
|
||||
except Exception as e:
|
||||
if "too short" in str(e):
|
||||
# tokenizers might yell that we want to truncate
|
||||
# to a value that is not even reached by the input.
|
||||
# In that case we don't want to truncate.
|
||||
# It seems there's not a really better way to catch that
|
||||
# exception.
|
||||
|
||||
inputs = self.tokenizer(
|
||||
sequence_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
)
|
||||
except Exception as e:
|
||||
if "too short" in str(e):
|
||||
# tokenizers might yell that we want to truncate
|
||||
# to a value that is not even reached by the input.
|
||||
# In that case we don't want to truncate.
|
||||
# It seems there's not a really better way to catch that
|
||||
# exception.
|
||||
|
||||
inputs = self.tokenizer(
|
||||
sequence_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
return inputs
|
||||
|
||||
@ -183,10 +172,6 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
- **labels** (`List[str]`) -- The labels sorted by order of likelihood.
|
||||
- **scores** (`List[float]`) -- The probabilities for each of the labels.
|
||||
"""
|
||||
if kwargs.get("batch_size", 1) > 1:
|
||||
logger.error("Batch size > 1 is not supported for zero-shot pipeline, setting batch_size=1.")
|
||||
kwargs["batch_size"] = 1
|
||||
|
||||
if len(args) == 0:
|
||||
pass
|
||||
elif len(args) == 1 and "candidate_labels" not in kwargs:
|
||||
@ -198,45 +183,35 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
|
||||
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
|
||||
model_inputs = self._parse_and_tokenize(sequence_pairs)
|
||||
|
||||
prepared_inputs = {
|
||||
"candidate_labels": candidate_labels,
|
||||
"sequences": sequences,
|
||||
"inputs": model_inputs,
|
||||
}
|
||||
return prepared_inputs
|
||||
for i, (candidate_label, sequence_pair) in enumerate(zip(candidate_labels, sequence_pairs)):
|
||||
model_input = self._parse_and_tokenize([sequence_pair])
|
||||
|
||||
yield {
|
||||
"candidate_label": candidate_label,
|
||||
"sequence": sequences[0],
|
||||
"is_last": i == len(candidate_labels) - 1,
|
||||
**model_input,
|
||||
}
|
||||
|
||||
def _forward(self, inputs):
|
||||
candidate_labels = inputs["candidate_labels"]
|
||||
sequences = inputs["sequences"]
|
||||
model_inputs = inputs["inputs"]
|
||||
if isinstance(model_inputs, list):
|
||||
outputs = []
|
||||
for input_ in model_inputs:
|
||||
prediction = self.model(**input_)[0].cpu()
|
||||
outputs.append(prediction)
|
||||
else:
|
||||
outputs = self.model(**model_inputs)
|
||||
candidate_label = inputs["candidate_label"]
|
||||
sequence = inputs["sequence"]
|
||||
model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
|
||||
outputs = self.model(**model_inputs)
|
||||
|
||||
model_outputs = {"candidate_labels": candidate_labels, "sequences": sequences, "outputs": outputs}
|
||||
model_outputs = {
|
||||
"candidate_label": candidate_label,
|
||||
"sequence": sequence,
|
||||
"is_last": inputs["is_last"],
|
||||
**outputs,
|
||||
}
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, multi_label=False):
|
||||
candidate_labels = model_outputs["candidate_labels"]
|
||||
sequences = model_outputs["sequences"]
|
||||
outputs = model_outputs["outputs"]
|
||||
|
||||
if self.framework == "pt":
|
||||
if isinstance(outputs, list):
|
||||
logits = np.concatenate([output.cpu().numpy() for output in outputs], axis=0)
|
||||
else:
|
||||
logits = outputs["logits"].cpu().numpy()
|
||||
else:
|
||||
if isinstance(outputs, list):
|
||||
logits = np.concatenate([output.numpy() for output in outputs], axis=0)
|
||||
else:
|
||||
logits = outputs["logits"].numpy()
|
||||
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
|
||||
sequences = [outputs["sequence"] for outputs in model_outputs]
|
||||
logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
|
||||
N = logits.shape[0]
|
||||
n = len(candidate_labels)
|
||||
num_sequences = N // n
|
||||
@ -254,16 +229,9 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
entail_logits = reshaped_outputs[..., self.entailment_id]
|
||||
scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
|
||||
|
||||
result = []
|
||||
for iseq in range(num_sequences):
|
||||
top_inds = list(reversed(scores[iseq].argsort()))
|
||||
result.append(
|
||||
{
|
||||
"sequence": sequences[iseq],
|
||||
"labels": [candidate_labels[i] for i in top_inds],
|
||||
"scores": scores[iseq, top_inds].tolist(),
|
||||
}
|
||||
)
|
||||
if len(result) == 1:
|
||||
return result[0]
|
||||
return result
|
||||
top_inds = list(reversed(scores[0].argsort()))
|
||||
return {
|
||||
"sequence": sequences[0],
|
||||
"labels": [candidate_labels[i] for i in top_inds],
|
||||
"scores": scores[0, top_inds].tolist(),
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
@ -72,7 +73,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
import numpy as np
|
||||
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
@ -101,7 +101,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_large(self):
|
||||
import numpy as np
|
||||
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
@ -113,8 +112,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": ""})
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
@ -130,8 +127,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
@ -140,8 +135,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
@slow
|
||||
@require_torch
|
||||
def test_simple_wav2vec2(self):
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@ -168,8 +161,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
def test_simple_s2t(self):
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-it-st")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/s2t-small-mustc-en-it-st")
|
||||
@ -204,8 +195,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
@ -222,8 +211,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
|
@ -204,8 +204,10 @@ class PipelineTestCaseMeta(type):
|
||||
# Need to copy because Conversation object is mutated
|
||||
yield copy.deepcopy(random.choice(examples))
|
||||
|
||||
out = []
|
||||
for item in pipeline(data(10), batch_size=4):
|
||||
pass
|
||||
out.append(item)
|
||||
self.assertEqual(len(out), 10)
|
||||
|
||||
run_batch_test(pipeline, examples)
|
||||
|
||||
@ -444,3 +446,141 @@ class PipelinePadTest(unittest.TestCase):
|
||||
torch.zeros((2, 11, 2), dtype=torch.long),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_torch
|
||||
class PipelineUtilsTest(unittest.TestCase):
|
||||
def test_pipeline_dataset(self):
|
||||
from transformers.pipelines.pt_utils import PipelineDataset
|
||||
|
||||
dummy_dataset = [0, 1, 2, 3]
|
||||
|
||||
def add(number, extra=0):
|
||||
return number + extra
|
||||
|
||||
dataset = PipelineDataset(dummy_dataset, add, {"extra": 2})
|
||||
self.assertEqual(len(dataset), 4)
|
||||
outputs = [dataset[i] for i in range(4)]
|
||||
self.assertEqual(outputs, [2, 3, 4, 5])
|
||||
|
||||
def test_pipeline_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
dummy_dataset = [0, 1, 2, 3]
|
||||
|
||||
def add(number, extra=0):
|
||||
return number + extra
|
||||
|
||||
dataset = PipelineIterator(dummy_dataset, add, {"extra": 2})
|
||||
self.assertEqual(len(dataset), 4)
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [2, 3, 4, 5])
|
||||
|
||||
def test_pipeline_iterator_no_len(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
def dummy_dataset():
|
||||
for i in range(4):
|
||||
yield i
|
||||
|
||||
def add(number, extra=0):
|
||||
return number + extra
|
||||
|
||||
dataset = PipelineIterator(dummy_dataset(), add, {"extra": 2})
|
||||
with self.assertRaises(TypeError):
|
||||
len(dataset)
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [2, 3, 4, 5])
|
||||
|
||||
def test_pipeline_batch_unbatch_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
dummy_dataset = [{"id": [0, 1, 2]}, {"id": [3]}]
|
||||
|
||||
def add(number, extra=0):
|
||||
return {"id": [i + extra for i in number["id"]]}
|
||||
|
||||
dataset = PipelineIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}])
|
||||
|
||||
def test_pipeline_batch_unbatch_iterator_tensors(self):
|
||||
import torch
|
||||
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
dummy_dataset = [{"id": torch.LongTensor([[10, 20], [0, 1], [0, 2]])}, {"id": torch.LongTensor([[3]])}]
|
||||
|
||||
def add(number, extra=0):
|
||||
return {"id": number["id"] + extra}
|
||||
|
||||
dataset = PipelineIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs), [{"id": [[12, 22]]}, {"id": [[2, 3]]}, {"id": [[2, 4]]}, {"id": [[5]]}]
|
||||
)
|
||||
|
||||
def test_pipeline_chunk_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelineChunkIterator
|
||||
|
||||
def preprocess_chunk(n: int):
|
||||
for i in range(n):
|
||||
yield i
|
||||
|
||||
dataset = [2, 3]
|
||||
|
||||
dataset = PipelineChunkIterator(dataset, preprocess_chunk, {}, loader_batch_size=3)
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
|
||||
self.assertEqual(outputs, [0, 1, 0, 1, 2])
|
||||
|
||||
def test_pipeline_pack_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelinePackIterator
|
||||
|
||||
def pack(item):
|
||||
return {"id": item["id"] + 1, "is_last": item["is_last"]}
|
||||
|
||||
dataset = [
|
||||
{"id": 0, "is_last": False},
|
||||
{"id": 1, "is_last": True},
|
||||
{"id": 0, "is_last": False},
|
||||
{"id": 1, "is_last": False},
|
||||
{"id": 2, "is_last": True},
|
||||
]
|
||||
|
||||
dataset = PipelinePackIterator(dataset, pack, {})
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[
|
||||
{"id": 1},
|
||||
{"id": 2},
|
||||
],
|
||||
[
|
||||
{"id": 1},
|
||||
{"id": 2},
|
||||
{"id": 3},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
def test_pipeline_pack_unbatch_iterator(self):
|
||||
from transformers.pipelines.pt_utils import PipelinePackIterator
|
||||
|
||||
dummy_dataset = [{"id": [0, 1, 2], "is_last": [False, True, False]}, {"id": [3], "is_last": [True]}]
|
||||
|
||||
def add(number, extra=0):
|
||||
return {"id": [i + extra for i in number["id"]], "is_last": number["is_last"]}
|
||||
|
||||
dataset = PipelinePackIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)
|
||||
|
||||
outputs = [item for item in dataset]
|
||||
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}], [{"id": 4}, {"id": 5}]])
|
||||
|
Loading…
Reference in New Issue
Block a user