Mirror of https://github.com/huggingface/transformers.git
[ci-daily] Fix pipeline tests (#21257)
* use streaming dataset
* fix whisper's test
* add rescale argument to chunk_iter
This commit is contained in:
parent 275ad9d80a
commit b80b2218b5
src/transformers/pipelines/automatic_speech_recognition.py

@@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
     return new_strides


-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
@@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right

         chunk_len = chunk.shape[0]
         stride = (chunk_len, _stride_left, _stride_right)
-        if ratio != 1:
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        if processed_len != chunk.shape[-1] and rescale:
+            ratio = processed_len / chunk_len
             stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
             yield {"is_last": is_last, "stride": stride, **processed}
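What changed in chunk_iter: the precomputed ratio argument is gone; the ratio is now derived on the fly as processed_len / chunk_len, and the new rescale flag lets a caller skip the rescaling entirely (Whisper does, see the rescale = self.type != "seq2seq_whisper" hunk further down). A self-contained sketch of what the rescaling accomplishes, with illustrative numbers rather than the pipeline's actual values:

# Illustrative numbers only: a 10 s chunk of 16 kHz audio with 1 s of stride
# on each side, processed by a hypothetical CTC feature extractor that emits
# one frame per 320 samples.
chunk_len = 160_000                      # samples in the chunk
stride = (chunk_len, 16_000, 16_000)     # (chunk_len, stride_left, stride_right)
processed_len = chunk_len // 320         # 500 frames out of the extractor

ratio = processed_len / chunk_len        # what the new code computes inline
rescaled = tuple(int(round(v * ratio)) for v in stride)
print(rescaled)                          # (500, 50, 50): same stride, in frame units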
@@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
             sequence = sequence[begin_idx:]

             timestamp_tokens = sequence >= timestamp_begin
-            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
-            last_timestamp = np.where(timestamp_tokens)[0][-1]
-            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
-            if seq_idx != 0:
+            if seq_idx != 0 and sum(timestamp_tokens) > 0:
+                consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+                last_timestamp = np.where(timestamp_tokens)[0][-1]
+                consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
                 time -= stride_left + stride_right
                 offset = int((time / feature_extractor.sampling_rate) / time_precision)
                 overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
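The reordering matters because np.where(timestamp_tokens)[0][-1] raises an IndexError when a sequence contains no timestamp tokens at all; moving the computation behind the sum(timestamp_tokens) > 0 guard makes that indexing safe. A standalone illustration of the failure mode:

import numpy as np

timestamp_tokens = np.array([False, False, False])  # a chunk with no timestamps

# Unguarded, this is what the old code could hit:
# np.where(timestamp_tokens)[0][-1]  -> IndexError: index -1 is out of bounds

if sum(timestamp_tokens) > 0:  # the guard added in the hunk above
    last_timestamp = np.where(timestamp_tokens)[0][-1]
else:
    last_timestamp = None
print(last_timestamp)  # None: the chunk is skipped instead of crashing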
@@ -400,13 +406,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     " only 1 version"
                 )
             forward_params["generate_kwargs"].update(generate_kwargs)
-        if return_timestamps is not None:
-            forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps

         postprocess_params = {}
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
         if self.model.config.model_type == "whisper":
             # Whisper is highly specific, if we want timestamps, we need to
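Caller-facing behaviour is unchanged: return_timestamps stays an ordinary pipeline keyword; this hunk only moves it out of generate_kwargs and into forward_params. A usage sketch (the checkpoint and file name are illustrative, not taken from this commit):

from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# "sample.flac" is a placeholder path; any audio file works.
result = asr("sample.flac", return_timestamps=True)
print(result["text"])
print(result["chunks"])  # [{"text": ..., "timestamp": (start, end)}, ...]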
@@ -502,9 +507,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")

+            rescale = self.type != "seq2seq_whisper"
             # make sure that
             for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
+                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
             ):
                 yield item
         else:
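Why Whisper opts out of rescaling: WhisperFeatureExtractor pads every chunk to a fixed 30-second mel spectrogram, so processed_len no longer tracks the true audio length and a ratio derived from it would distort the stride. Rough arithmetic with Whisper's published defaults, shown only to make the mismatch concrete:

sampling_rate = 16_000
hop_length = 160
padded_samples = 30 * sampling_rate            # every chunk padded to 480_000 samples
processed_len = padded_samples // hop_length   # 3000 mel frames, whatever the input

chunk_len = 10 * sampling_rate                 # an actual 10 s chunk
padded_ratio = processed_len / chunk_len       # 0.01875
true_ratio = (chunk_len // hop_length) / chunk_len  # 0.00625 without padding
# The two disagree by 3x, hence rescale = self.type != "seq2seq_whisper" above.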
@@ -520,12 +526,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}

-    def _forward(self, model_inputs, generate_kwargs=None):
+    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}

         is_last = model_inputs.pop("is_last")
-        return_timestamps = generate_kwargs.pop("return_timestamps", False)

         if self.type == "seq2seq":
             encoder = self.model.get_encoder()
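The signature change lines up with how Pipeline dispatches parameters: _sanitize_parameters splits call-time kwargs into (preprocess_params, forward_params, postprocess_params), and forward_params is splatted into _forward as keyword arguments; popping the flag out of generate_kwargs, as before, meant mutating a dict the user also controls. A simplified probe of that flow (this touches a private method, shown purely for illustration):

from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Private API, used here only to illustrate the parameter routing.
_, forward_params, _ = asr._sanitize_parameters(return_timestamps=True)
print(forward_params)  # expected to carry {"return_timestamps": True}, matching
                       # the explicit parameter in the new _forward signature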
@@ -635,9 +640,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             # Simply cast from pyctcdecode format to wav2vec2 format to leverage
             # pre-existing code later
             chunk_offset = beams[0][2]
-            word_offsets = []
+            offsets = []
             for word, (start_offset, end_offset) in chunk_offset:
-                word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
+                offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
         else:
             skip_special_tokens = self.type != "ctc"
             text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
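For orientation, beams[0][2] is pyctcdecode's text_frames field (a beam is, roughly, a tuple of text, LM state, text_frames, logit score, LM score): a list of (word, (start_frame, end_frame)) pairs that the loop reshapes into wav2vec2-style offset dicts, under the offsets name the downstream code expects. A standalone illustration with made-up frame values:

# Example text_frames pairs: (word, (start_frame, end_frame)).
chunk_offset = [("hello", (0, 12)), ("world", (15, 30))]

offsets = [
    {"word": word, "start_offset": start, "end_offset": end}
    for word, (start, end) in chunk_offset
]
print(offsets[0])  # {'word': 'hello', 'start_offset': 0, 'end_offset': 12}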
tests/pipelines/test_pipelines_automatic_speech_recognition.py

@@ -201,8 +201,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
     @require_torch
     @require_pyctcdecode
     def test_large_model_pt_with_lm(self):
-        dataset = load_dataset("Narsil/asr_dummy")
-        filename = dataset["test"][3]["file"]
+        dataset = load_dataset("Narsil/asr_dummy", streaming=True)
+        third_item = next(iter(dataset["test"].skip(3)))
+        filename = third_item["file"]

         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
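The test fix in practice: with streaming=True, load_dataset returns an IterableDataset that yields examples lazily over the network, so integer indexing like dataset["test"][3] is unavailable; .skip(3) plus next(iter(...)) reaches the same fourth item without downloading and caching the whole dataset. The same pattern in isolation:

from datasets import load_dataset

dataset = load_dataset("Narsil/asr_dummy", streaming=True)

# IterableDataset has no __getitem__; skip the first three examples and
# pull the next one off the iterator instead.
third_item = next(iter(dataset["test"].skip(3)))
print(third_item["file"])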