Adding support for microphone streaming within pipeline. (#15046)

* Adding support for `microphone` streaming within pipeline.

- Uses `ffmpeg` to get microphone data.
- Makes sure alignment is made to `size_of_sample`.
- Works by sending `{"raw": ..data.., "stride": (n, left, right),
"partial": bool}`
directly to the pipeline enabling to stream partial results and still
get inference.
- Let's `partial` information flow through the pipeline to enable caller
  to get it back and choose to display text or not.

- The striding reconstitution is bound to have errors since CTC does not
keep previous state. Currently most of the errors are we don't know if
there's a space or not between two chunks.
Since we have some left striding info, we could use that during decoding
to choose what to do with those spaces and even extra letters maybe (if
the stride is long enough, it's bound to cover at least a few symbols)

Fixing tests.

Protecting with `require_torch`.

`raw_ctc` support for nicer demo.

Post rebase fixes.

Revamp to split raw_mic_data from it's live chunking.

- Requires a refactor to make everything a bit cleaner.

Automatic resampling.

Small fix.

Small fix.

* Post rebase fix (need to let super handle more logic, reorder args.)

* Update docstrings

* Docstring format.

* Remove print.

* Prevent flow of `input_values`.

* Fixing `stride` too.

* Fixing the PR by removing `raw_ctc`.

* Better docstrings.

* Fixing init.

* Update src/transformers/pipelines/audio_utils.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Update tests/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Quality.

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
This commit is contained in:
Nicolas Patry 2022-02-02 15:12:12 +01:00 committed by GitHub
parent d718c0c3a8
commit 623d8cb475
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 413 additions and 62 deletions

View File

@ -0,0 +1,216 @@
import platform
import subprocess
from typing import Optional, Tuple, Union
import numpy as np
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
"""
Helper function to read an audio file through ffmpeg.
"""
ar = f"{sampling_rate}"
ac = "1"
format_for_conversion = "f32le"
ffmpeg_command = [
"ffmpeg",
"-i",
"pipe:0",
"-ac",
ac,
"-ar",
ar,
"-f",
format_for_conversion,
"-hide_banner",
"-loglevel",
"quiet",
"pipe:1",
]
try:
with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
output_stream = ffmpeg_process.communicate(bpayload)
except FileNotFoundError as error:
raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
out_bytes = output_stream[0]
audio = np.frombuffer(out_bytes, np.float32)
if audio.shape[0] == 0:
raise ValueError("Malformed soundfile")
return audio
def ffmpeg_microphone(
sampling_rate: int,
chunk_length_s: float,
format_for_conversion: str = "f32le",
):
"""
Helper function ro read raw microphone data.
"""
ar = f"{sampling_rate}"
ac = "1"
if format_for_conversion == "s16le":
size_of_sample = 2
elif format_for_conversion == "f32le":
size_of_sample = 4
else:
raise ValueError("Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")
system = platform.system()
if system == "Linux":
format_ = "alsa"
input_ = "default"
elif system == "Darwin":
format_ = "avfoundation"
input_ = ":0"
elif system == "Windows":
format_ = "dshow"
input_ = "default"
ffmpeg_command = [
"ffmpeg",
"-f",
format_,
"-i",
input_,
"-ac",
ac,
"-ar",
ar,
"-f",
format_for_conversion,
"-fflags",
"nobuffer",
"-hide_banner",
"-loglevel",
"quiet",
"pipe:1",
]
chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
iterator = _ffmpeg_stream(ffmpeg_command, chunk_len)
for item in iterator:
yield item
def ffmpeg_microphone_live(
sampling_rate: int,
chunk_length_s: float,
stream_chunk_s: Optional[int] = None,
stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
format_for_conversion: str = "f32le",
):
"""
Helper function to read audio from the microphone file through ffmpeg. This will output `partial` overlapping
chunks starting from `stream_chunk_s` (if it is defined) until `chunk_length_s` is reached. It will make use of
striding to avoid errors on the "sides" of the various chunks.
Arguments:
sampling_rate (`int`):
The sampling_rate to use when reading the data from the microphone. Try using the model's sampling_rate to
avoid resampling later.
chunk_length_s (`float` or `int`):
The length of the maximum chunk of audio to be sent returned. This includes the eventual striding.
stream_chunk_s (`float` or `int`)
The length of the minimal temporary audio to be returned.
stride_length_s (`float` or `int` or `(float, float)`, *optional*, defaults to `None`)
The length of the striding to be used. Stride is used to provide context to a model on the (left, right) of
an audio sample but without using that part to actually make the prediction. Setting this does not change
the length of the chunk.
format_for_conversion: (`str`, defalts to `f32le`)
The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
could also be used.
Return:
A generator yielding dictionaries of the following form
`{"sampling_rate": int, "raw": np.array(), "partial" bool}` With optionnally a `"stride" (int, int)` key if
`stride_length_s` is defined.
`stride` and `raw` are all expressed in `samples`, and `partial` is a boolean saying if the current yield item
is a whole chunk, or a partial temporary result to be later replaced by another larger chunk.
"""
if stream_chunk_s is not None:
chunk_s = stream_chunk_s
else:
chunk_s = chunk_length_s
microphone = ffmpeg_microphone(sampling_rate, chunk_s, format_for_conversion=format_for_conversion)
if format_for_conversion == "s16le":
dtype = np.int16
size_of_sample = 2
elif format_for_conversion == "f32le":
dtype = np.float32
size_of_sample = 4
else:
raise ValueError("Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")
if stride_length_s is None:
stride_length_s = chunk_length_s / 6
chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
if isinstance(stride_length_s, (int, float)):
stride_length_s = [stride_length_s, stride_length_s]
stride_left = int(round(sampling_rate * stride_length_s[0])) * size_of_sample
stride_right = int(round(sampling_rate * stride_length_s[1])) * size_of_sample
for item in chunk_bytes_iter(microphone, chunk_len, stride=(stride_left, stride_right), stream=True):
# Put everything back in numpy scale
item["raw"] = np.frombuffer(item["raw"], dtype=dtype)
item["stride"] = (
item["stride"][0] // size_of_sample,
item["stride"][1] // size_of_sample,
)
item["sampling_rate"] = sampling_rate
yield item
def chunk_bytes_iter(iterator, chunk_len: int, stride: Tuple[int, int], stream: bool = False):
"""
Reads raw bytes from an iterator and does chunks of length `chunk_len`. Optionally adds `stride` to each chunks to
get overlaps. `stream` is used to return partial results even if a full `chunk_len` is not yet available.
"""
acc = b""
stride_left, stride_right = stride
if stride_left + stride_right >= chunk_len:
raise ValueError(
f"Stride needs to be strictly smaller than chunk_len: ({stride_left}, {stride_right}) vs {chunk_len}"
)
_stride_left = 0
for raw in iterator:
acc += raw
if stream and len(acc) < chunk_len:
stride = (_stride_left, 0)
yield {"raw": acc[:chunk_len], "stride": stride, "partial": True}
else:
while len(acc) >= chunk_len:
# We are flushing the accumulator
stride = (_stride_left, stride_right)
item = {"raw": acc[:chunk_len], "stride": stride}
if stream:
item["partial"] = False
yield item
_stride_left = stride_left
acc = acc[chunk_len - stride_left - stride_right :]
# Last chunk
if len(acc) > stride_left:
item = {"raw": acc, "stride": (_stride_left, 0)}
if stream:
item["partial"] = False
yield item
def _ffmpeg_stream(ffmpeg_command, buflen: int):
"""
Internal function to create the generator of data through ffmpeg
"""
bufsize = 2 ** 24 # 16Mo
try:
with subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, bufsize=bufsize) as ffmpeg_process:
while True:
raw = ffmpeg_process.stdout.read(buflen)
if raw == b"":
break
yield raw
except FileNotFoundError as error:
raise ValueError("ffmpeg was not found but is required to stream audio files from filename") from error

View File

@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from collections import defaultdict
from typing import TYPE_CHECKING, Union
import numpy as np
from ..file_utils import is_torch_available
from ..utils import logging
from .audio_utils import ffmpeg_read
from .base import ChunkPipeline
@ -30,42 +31,6 @@ if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
"""
Helper function to read an audio file through ffmpeg.
"""
ar = f"{sampling_rate}"
ac = "1"
format_for_conversion = "f32le"
ffmpeg_command = [
"ffmpeg",
"-i",
"pipe:0",
"-ac",
ac,
"-ar",
ar,
"-f",
format_for_conversion,
"-hide_banner",
"-loglevel",
"quiet",
"pipe:1",
]
try:
ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
except FileNotFoundError:
raise ValueError("ffmpeg was not found but is required to load audio files from filename")
output_stream = ffmpeg_process.communicate(bpayload)
out_bytes = output_stream[0]
audio = np.frombuffer(out_bytes, np.float32)
if audio.shape[0] == 0:
raise ValueError("Malformed soundfile")
return audio
def rescale_stride(tokens_or_logits, stride):
"""
Rescales the stride values from audio space to tokens/logits space.
@ -130,14 +95,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *args, **kwargs):
"""
Arguments:
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode waveform for the model.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting
from [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode waveform for the model.
chunk_length_s (`float`, *optional*, defaults to 0):
The input length for in each chunk. If `0` then chunking is disabled (default). Only available for CTC
models.
@ -156,20 +121,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the
model on the associated CUDA device id.
"""
super().__init__(*args, **kwargs)
self.feature_extractor = feature_extractor
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
self.type = "seq2seq"
elif (
self.feature_extractor._processor_class
and self.feature_extractor._processor_class.endswith("WithLM")
feature_extractor._processor_class
and feature_extractor._processor_class.endswith("WithLM")
and kwargs.get("decoder", None) is not None
):
self.decoder = kwargs["decoder"]
@ -177,6 +137,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
else:
self.type = "ctc"
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
def __call__(
self,
inputs: Union[np.ndarray, bytes, str],
@ -187,12 +152,21 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
information.
Args:
inputs (`np.ndarray` or `bytes` or `str`):
The inputs is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
at the correct sampling rate (no further check will be done) or a `str` that is the filename of the
audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This
requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
The inputs is either :
- `str` that is the filename of the
audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*.
This
requires *ffmpeg* to be installed on the system.
- `bytes` it is supposed to be the
content of an audio file and is interpreted by *ffmpeg* in the same way.
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
Raw audio at the correct sampling rate (no further check will be done)
- `dict` form can be used to pass raw audio sampled at arbirary `sampling_rate` and let
this pipeline do the resampling. The dict must be in the fomat `{"sampling_rate": int, "raw":
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to treat the
first `left` samples and last `right` samples to be ignored in decoding (but used at inference to
provide more context to the model). Only use `stride` with CTC models.
Return:
A `dict` with the following keys:
@ -208,6 +182,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
if "stride_length_s" in kwargs:
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
return preprocess_params, {}, {}
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
@ -218,8 +193,35 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
stride = None
extra = {}
if isinstance(inputs, dict):
stride = inputs.pop("stride", None)
_inputs = inputs.pop("raw")
in_sampling_rate = inputs.pop("sampling_rate")
extra = inputs
inputs = _inputs
if in_sampling_rate != self.feature_extractor.sampling_rate:
import torch
from torchaudio import functional as F
inputs = F.resample(
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
).numpy()
ratio = self.feature_extractor.sampling_rate / in_sampling_rate
else:
ratio = 1
if stride is not None:
if stride[0] + stride[1] > inputs.shape[0]:
raise ValueError("Stride is too large for input")
# Stride needs to get the chunk length here, it's going to get
# swallowed by the `feature_extractor` later, and then batching
# can add extra data in the inputs, so we need to keep track
# of the original length in the stride so we can cut properly.
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
if not isinstance(inputs, np.ndarray):
raise ValueError("We expect a numpy ndarray as input")
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
@ -249,7 +251,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
yield {"is_last": True, **processed}
if stride is not None:
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
raise ValueError("Stride is only usable with CTC models, try removing it")
processed["stride"] = stride
yield {"is_last": True, **processed, **extra}
def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last")
@ -259,13 +266,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# attention mask length is different from expected text decoder `encoder_attention_mask` length
# `generate` magic to create the mask automatically won't work, we basically need to help
# it here.
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
input_features = model_inputs.pop("input_features")
attention_mask = model_inputs.pop("attention_mask")
tokens = self.model.generate(
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
encoder_outputs=encoder(input_features=input_features, attention_mask=attention_mask),
attention_mask=attention_mask,
)
out = {"tokens": tokens}
elif self.type == "ctc_with_lm":
stride = model_inputs.pop("stride", None)
outputs = self.model(**model_inputs)
input_values = model_inputs.pop("input_values")
attention_mask = model_inputs.pop("attention_mask", None)
outputs = self.model(input_values=input_values, attention_mask=attention_mask)
logits = outputs.logits
out = {"logits": logits}
if stride is not None:
@ -278,7 +292,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
out["stride"] = rescale_stride(logits, stride)
elif self.type == "ctc":
stride = model_inputs.pop("stride", None)
outputs = self.model(**model_inputs)
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
input_values = model_inputs.pop("input_values")
attention_mask = model_inputs.pop("attention_mask", None)
outputs = self.model(input_values=input_values, attention_mask=attention_mask)
tokens = outputs.logits.argmax(dim=-1)
if stride is not None:
if isinstance(stride, tuple):
@ -291,14 +309,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
outputs = self.model(**model_inputs)
tokens = outputs.logits.argmax(dim=-1)
out = {"tokens": tokens}
return {"is_last": is_last, **out}
# Leftover
extra = model_inputs
return {"is_last": is_last, **out, **extra}
def postprocess(self, model_outputs):
if self.type == "ctc_with_lm":
final_logits = []
for outputs in model_outputs:
logits = outputs["logits"].numpy()
stride = outputs.get("stride", None)
stride = outputs.pop("stride", None)
if stride is not None:
total_n, left, right = stride
# Total_n might be < logits.shape[1]
@ -316,4 +336,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
tokens = tokens.squeeze(0)
text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
return {"text": text}
extra = defaultdict(list)
for output in model_outputs:
output.pop("tokens", None)
output.pop("logits", None)
for k, v in output.items():
if k == "is_last":
continue
extra[k].append(v)
return {"text": text, **extra}

View File

@ -27,6 +27,7 @@ from transformers import (
Wav2Vec2ForCTC,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import apply_stride, chunk_iter
from transformers.testing_utils import (
is_pipeline_test,
@ -80,6 +81,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
else:
# Non CTC models cannot use striding.
with self.assertRaises(ValueError):
outputs = speech_recognizer(audio)
@require_torch
@slow
def test_pt_defaults(self):
@ -87,7 +97,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@require_torch
def test_small_model_pt(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
@ -180,7 +189,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@slow
@require_torch
def test_simple_wav2vec2(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
@ -455,6 +463,28 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
# (85, 100)
self.assertEqual(nested_simplify(input_values[:, 80:100]), nested_simplify(outs[4]["input_values"]))
@require_torch
def test_stride(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="hf-internal-testing/tiny-random-wav2vec2",
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 10)
output = speech_recognizer({"raw": waveform, "stride": (0, 0), "sampling_rate": 16_000})
self.assertEqual(output, {"text": "OB XB B EB BB B EB B OB X"})
# 0 effective ids Just take the middle one
output = speech_recognizer({"raw": waveform, "stride": (5000, 5000), "sampling_rate": 16_000})
self.assertEqual(output, {"text": "B"})
# Only 1 arange.
output = speech_recognizer({"raw": waveform, "stride": (0, 9000), "sampling_rate": 16_000})
self.assertEqual(output, {"text": "O"})
# 2nd arange
output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16_000})
self.assertEqual(output, {"text": "B XB"})
@require_torch
class ApplyStrideTest(unittest.TestCase):
@ -488,3 +518,79 @@ class ApplyStrideTest(unittest.TestCase):
tokens = torch.arange(10).long().reshape((2, 5))
apply_stride(tokens, [(100, 20, 0), (60, 0, 20)])
self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 6, 6, 6]], tokens.tolist())
def require_ffmpeg(test_case):
"""
Decorator marking a test that requires FFmpeg.
These tests are skipped when FFmpeg isn't installed.
"""
import subprocess
try:
subprocess.check_output(["ffmpeg", "-h"], stderr=subprocess.DEVNULL)
return test_case
except Exception:
return unittest.skip("test requires ffmpeg")(test_case)
def bytes_iter(chunk_size, chunks):
for i in range(chunks):
yield bytes(range(i * chunk_size, (i + 1) * chunk_size))
@require_ffmpeg
class AudioUtilsTest(unittest.TestCase):
def test_chunk_bytes_iter_too_big(self):
iter_ = iter(chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 10, stride=(0, 0)))
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02\x03\x04\x05", "stride": (0, 0)})
with self.assertRaises(StopIteration):
next(iter_)
def test_chunk_bytes_iter(self):
iter_ = iter(chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 3, stride=(0, 0)))
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02", "stride": (0, 0)})
self.assertEqual(next(iter_), {"raw": b"\x03\x04\x05", "stride": (0, 0)})
with self.assertRaises(StopIteration):
next(iter_)
def test_chunk_bytes_iter_stride(self):
iter_ = iter(chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 3, stride=(1, 1)))
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02", "stride": (0, 1)})
self.assertEqual(next(iter_), {"raw": b"\x01\x02\x03", "stride": (1, 1)})
self.assertEqual(next(iter_), {"raw": b"\x02\x03\x04", "stride": (1, 1)})
# This is finished, but the chunk_bytes doesn't know it yet.
self.assertEqual(next(iter_), {"raw": b"\x03\x04\x05", "stride": (1, 1)})
self.assertEqual(next(iter_), {"raw": b"\x04\x05", "stride": (1, 0)})
with self.assertRaises(StopIteration):
next(iter_)
def test_chunk_bytes_iter_stride_stream(self):
iter_ = iter(chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 5, stride=(1, 1), stream=True))
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02", "stride": (0, 0), "partial": True})
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02\x03\x04", "stride": (0, 1), "partial": False})
self.assertEqual(next(iter_), {"raw": b"\x03\x04\x05", "stride": (1, 0), "partial": False})
with self.assertRaises(StopIteration):
next(iter_)
iter_ = iter(chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=3), 5, stride=(1, 1), stream=True))
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02", "stride": (0, 0), "partial": True})
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02\x03\x04", "stride": (0, 1), "partial": False})
self.assertEqual(next(iter_), {"raw": b"\x03\x04\x05\x06\x07", "stride": (1, 1), "partial": False})
self.assertEqual(next(iter_), {"raw": b"\x06\x07\x08", "stride": (1, 0), "partial": False})
with self.assertRaises(StopIteration):
next(iter_)
iter_ = iter(chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=3), 10, stride=(1, 1), stream=True))
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02", "stride": (0, 0), "partial": True})
self.assertEqual(next(iter_), {"raw": b"\x00\x01\x02\x03\x04\x05", "stride": (0, 0), "partial": True})
self.assertEqual(
next(iter_), {"raw": b"\x00\x01\x02\x03\x04\x05\x06\x07\x08", "stride": (0, 0), "partial": True}
)
self.assertEqual(
next(iter_), {"raw": b"\x00\x01\x02\x03\x04\x05\x06\x07\x08", "stride": (0, 0), "partial": False}
)
with self.assertRaises(StopIteration):
next(iter_)