Hotfix `chunk_length_s` instead of `_ms`. (#15029)

* Hotfix `chunk_length_s` instead of `_ms`.
* Adding a fix for `pad_token`, which should be the last/previous token, for proper CTC decoding.
* Fixing ChunkPipeline unwrapping.
* Adding a `PipelinePackIterator`-specific test.
This commit is contained in:
parent 21aecc0971
commit 19d37c2dd3
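For reference, a minimal usage sketch of the renamed arguments (not part of the commit), assuming the `facebook/wav2vec2-base-960h` checkpoint used in the tests below and a hypothetical local `sample.flac` file. `chunk_length_s` and `stride_length_s` are now given in seconds rather than milliseconds, and chunking is only available for CTC models.

```python
from transformers import pipeline

# Assumption: any CTC checkpoint works here; this one is the checkpoint used in the tests below.
speech_recognizer = pipeline(
    "automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
    chunk_length_s=10.0,           # was `chunk_length_ms=10_000`
    stride_length_s=(4.0, 2.0),    # (left, right); defaults to `chunk_length_s / 6` on each side
)

# "sample.flac" is a placeholder path for any single-channel audio file.
print(speech_recognizer("sample.flac")["text"])
```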
@@ -66,6 +66,39 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
     return audio
 
 
+def apply_stride(tokens, stride):
+    max_token_n = tokens.shape[-1]
+    max_input_n = max(input_n for input_n, _, _ in stride)
+    ratio = max_token_n / max_input_n
+    for i, (input_n, left, right) in enumerate(stride):
+        token_n = int(round(input_n * ratio))
+        left_token = int(round(left / input_n * token_n))
+        right_token = int(round((input_n - right) / input_n * token_n))
+        # This is CTC: to preserve decoding, we need to duplicate the
+        # next letter (after the left stride) and the last letter (before the right stride)
+
+        first_letter = tokens[i, left_token]
+        tokens[i, :left_token] = first_letter
+
+        last_letter = tokens[i, right_token - 1]
+        tokens[i, right_token:] = last_letter
+
+
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
+    inputs_len = inputs.shape[0]
+    step = chunk_len - stride_left - stride_right
+    for i in range(0, inputs_len, step):
+        # add start and end paddings to the chunk
+        chunk = inputs[i : i + chunk_len]
+        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
+        _stride_left = 0 if i == 0 else stride_left
+        is_last = i + step >= inputs_len
+        _stride_right = 0 if is_last else stride_right
+
+        if chunk.shape[0] > _stride_left:
+            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
+
+
 class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     """
     Pipeline that aims at extracting spoken text contained within some audio.
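For readers following the stride bookkeeping, here is a small sketch (not part of the commit) of the arithmetic `chunk_iter` performs on raw sample counts; the numbers mirror the `chunk_iter(inputs, feature_extractor, 30, 5, 5)` case asserted in the new tests.

```python
# Toy numbers only: 100 samples, chunks of 30 with a stride of 5 on each side.
chunk_len, stride_left, stride_right = 30, 5, 5
inputs_len = 100

step = chunk_len - stride_left - stride_right  # 20 "useful" samples advance per chunk
for i in range(0, inputs_len, step):
    chunk_size = min(chunk_len, inputs_len - i)
    _stride_left = 0 if i == 0 else stride_left
    is_last = i + step >= inputs_len
    _stride_right = 0 if is_last else stride_right
    if chunk_size > _stride_left:
        print((chunk_size, _stride_left, _stride_right), is_last)
# -> (30, 0, 5) False, (30, 5, 5) False, (30, 5, 5) False, (30, 5, 5) False, (20, 5, 0) True
```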
@@ -85,11 +118,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         tokenizer ([`PreTrainedTokenizer`]):
             The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             [`PreTrainedTokenizer`].
-        chunk_length_ms (`int`, *optional*, defaults to 0):
+        chunk_length_s (`float`, *optional*, defaults to 0):
             The input length for each chunk. If `0` then chunking is disabled (default). Only available for CTC
             models.
-        stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms / 6`):
-            The length of stride on the left and right of each chunk. Used only with `chunk_length_ms > 0`. This
+        stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
+            The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This
             enables the model to *see* more context and infer letters better than without this context but the
             pipeline discards the stride bits at the end to make the final reconstitution as perfect as possible.
         framework (`str`, *optional*):
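Since the arguments are now in seconds, the conversion to sample counts done in `preprocess` is a plain multiplication by the feature extractor's sampling rate. A quick sketch (not part of the commit), assuming the usual 16 kHz rate of Wav2Vec2 feature extractors:

```python
sampling_rate = 16_000  # assumption: Wav2Vec2 feature extractors use 16 kHz
chunk_length_s = 10.0
stride_length_s = chunk_length_s / 6  # the documented default

chunk_len = int(round(chunk_length_s * sampling_rate))    # 160000 samples
stride = int(round(stride_length_s * sampling_rate))      # 26667 samples on each side
assert chunk_len >= stride + stride  # mirrors the "Chunk length must be superior to stride length" check
```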
@@ -111,6 +144,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
 
         self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
+        self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values()
 
     def __call__(
         self,
@@ -139,13 +173,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     def _sanitize_parameters(self, **kwargs):
         # No parameters on this pipeline right now
         preprocess_params = {}
-        if "chunk_length_ms" in kwargs:
-            preprocess_params["chunk_length_ms"] = kwargs["chunk_length_ms"]
-        if "stride_length_ms" in kwargs:
-            preprocess_params["stride_length_ms"] = kwargs["stride_length_ms"]
+        if "chunk_length_s" in kwargs:
+            preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
+        if "stride_length_s" in kwargs:
+            preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
         return preprocess_params, {}, {}
 
-    def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             with open(inputs, "rb") as f:
                 inputs = f.read()
@@ -158,39 +192,28 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
 
-        if chunk_length_ms:
-            if stride_length_ms is None:
-                stride_length_ms = chunk_length_ms // 6
-            inputs_len = len(inputs)
-            chunk_len = chunk_length_ms * self.feature_extractor.sampling_rate // 1000
-            stride_len = stride_length_ms * self.feature_extractor.sampling_rate // 1000
+        if chunk_length_s:
+            if stride_length_s is None:
+                stride_length_s = chunk_length_s / 6
+
-            # Redefine chunk_len to useful chunk length
-            # Not the size
-            # chunk_len = chunk_len - 2 * stride_len
+            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate))
+
-            if self.model.__class__ not in MODEL_FOR_CTC_MAPPING.values():
+            if isinstance(stride_length_s, (int, float)):
+                stride_length_s = [stride_length_s, stride_length_s]
+
+            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate))
+            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate))
+
+            if not self.is_ctc:
                 raise ValueError(
-                    "`chunk_length_ms` is only valid for CTC models, use other chunking options for other models"
+                    "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
                 )
-            if chunk_len < stride_len:
+            if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")
 
             # make sure that
-            step = chunk_len
-            for i in range(0, inputs_len, step):
-                # add start and end paddings to the chunk
-                start = 0 if i - stride_len < 0 else i - stride_len
-                stop = inputs_len if i + chunk_len + stride_len > inputs_len else i + chunk_len + stride_len
-                chunk = inputs[start:stop]
-                processed = self.feature_extractor(
-                    chunk, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-                )
-                stride_left = i - start
-                stride_right = max(stop - (i + chunk_len), 0)
-                is_last = i + step > inputs_len
-
-                yield {"is_last": is_last, "stride": (stop - start, stride_left, stride_right), **processed}
+            for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right):
+                yield item
         else:
             processed = self.feature_extractor(
                 inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
@@ -198,8 +221,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             yield {"is_last": True, **processed}
 
     def _forward(self, model_inputs):
-        model_class = self.model.__class__
         is_last = model_inputs.pop("is_last")
+        model_class = self.model.__class__
         if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
             encoder = self.model.get_encoder()
             # we need to pass `processed.get("attention_mask")` here since audio encoder
@@ -217,15 +240,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 if isinstance(stride, tuple):
                     stride = [stride]
 
-                max_token_n = tokens.shape[-1]
-                max_input_n = max(input_n for input_n, _, _ in stride)
-                ratio = max_token_n / max_input_n
-                for i, (input_n, left, right) in enumerate(stride):
-                    token_n = int(input_n * ratio) + 1
-                    left_token = int(left / input_n * token_n)
-                    right_token = int((input_n - right) / input_n * token_n) + 1
-                    tokens[i, :left_token] = self.tokenizer.pad_token_id
-                    tokens[i, right_token:] = self.tokenizer.pad_token_id
+                apply_stride(tokens, stride)
         else:
             logger.warning("This is an unknown class, treating it as CTC.")
             outputs = self.model(**model_inputs)
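To see the `pad_token` fix in isolation: `apply_stride` now fills the stride regions with the neighbouring letter instead of `pad_token_id`, so CTC decoding simply collapses the repeats. The snippet below reproduces one of the new test cases from this commit; the token ids are arbitrary.

```python
import torch

from transformers.pipelines.automatic_speech_recognition import apply_stride

# Two rows of 5 argmax token ids; strides are (input_n, left, right) in samples.
tokens = torch.arange(10).long().reshape((2, 5))
apply_stride(tokens, [(100, 20, 0), (100, 0, 20)])
print(tokens.tolist())  # [[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]]
# Repeated ids are collapsed by CTC decoding, whereas a pad id written into the
# stride region could cut a letter that spans the chunk boundary.
```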
@@ -276,7 +276,7 @@ class PipelinePackIterator(PipelineIterator):
             else:
                 item = processed
                 is_last = item.pop("is_last")
-            accumulator.append(item)
+                accumulator.append(item)
         return accumulator
 
 
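Conceptually, `PipelinePackIterator` regroups the flattened chunk outputs into one list per original input, closing a group whenever an item carries `is_last=True`. A simplified sketch of that contract (plain lists, no loader batching, not the real implementation):

```python
def pack(items):
    """Group items into sub-lists, closing a group when `is_last` is True."""
    packed, accumulator = [], []
    for item in items:
        is_last = item.pop("is_last")
        accumulator.append(item)
        if is_last:
            packed.append(accumulator)
            accumulator = []
    return packed

items = [
    {"id": 0, "is_last": False},
    {"id": 1, "is_last": True},
    {"id": 2, "is_last": False},
    {"id": 3, "is_last": True},
]
print(pack(items))  # [[{'id': 0}, {'id': 1}], [{'id': 2}, {'id': 3}]]
```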
@@ -27,11 +27,24 @@ from transformers import (
     Wav2Vec2ForCTC,
 )
 from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
-from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
+from transformers.pipelines.automatic_speech_recognition import apply_stride, chunk_iter
+from transformers.testing_utils import (
+    is_pipeline_test,
+    is_torch_available,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    require_torchaudio,
+    slow,
+)
 
 from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
 
+if is_torch_available():
+    import torch
+
+
 # We can't use this mixin because it assumes TF support.
 # from .test_pipelines_common import CustomInputPipelineCommonMixin
 
@@ -245,17 +258,119 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             tokenizer=tokenizer,
             feature_extractor=feature_extractor,
             framework="pt",
-            chunk_length_ms=10_000,
+            chunk_length_s=10.0,
         )
 
         from datasets import load_dataset
 
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
         audio = ds[40]["audio"]["array"]
 
-        n_repeats = 100
+        n_repeats = 10
         audio = np.tile(audio, n_repeats)
         output = speech_recognizer([audio], batch_size=2)
         expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
         expected = [{"text": expected_text.strip()}]
         self.assertEqual(output, expected)
+
+    @require_torch
+    def test_chunk_iterator(self):
+        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        inputs = torch.arange(100).long()
+
+        outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
+        self.assertEqual(len(outs), 1)
+        self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
+        self.assertEqual([o["is_last"] for o in outs], [True])
+
+        # two chunks no stride
+        outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+        # two chunks incomplete last
+        outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+    @require_torch
+    def test_chunk_iterator_stride(self):
+        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        inputs = torch.arange(100).long()
+        input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
+            "input_values"
+        ]
+
+        outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+        outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+        outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
+
+        inputs = torch.LongTensor([i % 2 for i in range(100)])
+        input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
+            "input_values"
+        ]
+        outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
+        self.assertEqual(len(outs), 5)
+        self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])
+        self.assertEqual([o["is_last"] for o in outs], [False, False, False, False, True])
+        # (0, 25)
+        self.assertEqual(nested_simplify(input_values[:, :30]), nested_simplify(outs[0]["input_values"]))
+        # (25, 45)
+        self.assertEqual(nested_simplify(input_values[:, 20:50]), nested_simplify(outs[1]["input_values"]))
+        # (45, 65)
+        self.assertEqual(nested_simplify(input_values[:, 40:70]), nested_simplify(outs[2]["input_values"]))
+        # (65, 85)
+        self.assertEqual(nested_simplify(input_values[:, 60:90]), nested_simplify(outs[3]["input_values"]))
+        # (85, 100)
+        self.assertEqual(nested_simplify(input_values[:, 80:100]), nested_simplify(outs[4]["input_values"]))
+
+
+@require_torch
+class ApplyStrideTest(unittest.TestCase):
+    def test_apply_stride(self):
+        tokens = torch.arange(10).long().reshape((2, 5))
+
+        # No stride
+        apply_stride(tokens, [(100, 0, 0), (100, 0, 0)])
+
+        expected = torch.arange(10).long().reshape((2, 5))
+        self.assertEqual(expected.tolist(), tokens.tolist())
+
+    def test_apply_stride_real_stride(self):
+        # Stride aligned
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 20, 0), (100, 0, 20)])
+        self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]], tokens.tolist())
+
+        # Stride rounded
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 15, 0), (100, 0, 15)])
+        self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]], tokens.tolist())
+
+        # No stride rounded
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 5, 0), (100, 0, 5)])
+        self.assertEqual([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], tokens.tolist())
+
+    def test_apply_stride_with_padding(self):
+        # Stride aligned
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 20, 0), (60, 0, 20)])
+        self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 6, 6, 6]], tokens.tolist())
@@ -584,3 +584,14 @@ class PipelineUtilsTest(unittest.TestCase):
 
         outputs = [item for item in dataset]
         self.assertEqual(outputs, [[{"id": 2}, {"id": 3}], [{"id": 4}, {"id": 5}]])
+
+        # is_false Across batch
+        dummy_dataset = [{"id": [0, 1, 2], "is_last": [False, False, False]}, {"id": [3], "is_last": [True]}]
+
+        def add(number, extra=0):
+            return {"id": [i + extra for i in number["id"]], "is_last": number["is_last"]}
+
+        dataset = PipelinePackIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)
+
+        outputs = [item for item in dataset]
+        self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])