[ci-daily] Fix pipeline tests (#21257)

* use streaming dataset

* fix whisper's test

* add rescale argument to chunk_iter
This commit is contained in:
Arthur 2023-01-23 19:32:49 +01:00 committed by GitHub
parent 275ad9d80a
commit b80b2218b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 15 deletions

View File

@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
return new_strides
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
for i in range(0, inputs_len, step):
@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
_stride_left = 0 if i == 0 else stride_left
is_last = i + step + stride_left >= inputs_len
_stride_right = 0 if is_last else stride_right
chunk_len = chunk.shape[0]
stride = (chunk_len, _stride_left, _stride_right)
if ratio != 1:
if "input_features" in processed:
processed_len = processed["input_features"].shape[-1]
elif "input_values" in processed:
processed_len = processed["input_values"].shape[-1]
if processed_len != chunk.shape[-1] and rescale:
ratio = processed_len / chunk_len
stride = rescale_stride([stride], ratio)[0]
if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": stride, **processed}
@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
sequence = sequence[begin_idx:]
timestamp_tokens = sequence >= timestamp_begin
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
last_timestamp = np.where(timestamp_tokens)[0][-1]
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
if seq_idx != 0:
if seq_idx != 0 and sum(timestamp_tokens) > 0:
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
last_timestamp = np.where(timestamp_tokens)[0][-1]
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
time -= stride_left + stride_right
offset = int((time / feature_extractor.sampling_rate) / time_precision)
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
@ -400,13 +406,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
" only 1 version"
)
forward_params["generate_kwargs"].update(generate_kwargs)
if return_timestamps is not None:
forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
postprocess_params = {}
if decoder_kwargs is not None:
postprocess_params["decoder_kwargs"] = decoder_kwargs
if return_timestamps is not None:
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if self.model.config.model_type == "whisper":
# Whisper is highly specific, if we want timestamps, we need to
@ -502,9 +507,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length")
rescale = self.type != "seq2seq_whisper"
# make sure that
for item in chunk_iter(
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
):
yield item
else:
@ -520,12 +526,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
processed["stride"] = stride
yield {"is_last": True, **processed, **extra}
def _forward(self, model_inputs, generate_kwargs=None):
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}
is_last = model_inputs.pop("is_last")
return_timestamps = generate_kwargs.pop("return_timestamps", False)
if self.type == "seq2seq":
encoder = self.model.get_encoder()
@ -635,9 +640,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
# pre-existing code later
chunk_offset = beams[0][2]
word_offsets = []
offsets = []
for word, (start_offset, end_offset) in chunk_offset:
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
else:
skip_special_tokens = self.type != "ctc"
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)

View File

@ -201,8 +201,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@require_torch
@require_pyctcdecode
def test_large_model_pt_with_lm(self):
dataset = load_dataset("Narsil/asr_dummy")
filename = dataset["test"][3]["file"]
dataset = load_dataset("Narsil/asr_dummy", streaming=True)
third_item = next(iter(dataset["test"].skip(3)))
filename = third_item["file"]
speech_recognizer = pipeline(
task="automatic-speech-recognition",