Adding timestamps for CTC with LM in ASR pipeline. (#15863)

* Adding timestamps for CTC with LM in ASR pipeline.

* Remove print.

* Nit change.
This commit is contained in:
Nicolas Patry 2022-03-02 10:49:05 +01:00 committed by GitHub
parent 8a133490bf
commit 6e57a56987
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 24 deletions

View File

@ -353,7 +353,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
word = char
last_state = state
if state == "WORD":
if last_state == "WORD":
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
return word_offsets

View File

@ -313,8 +313,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# Optional return types
optional = {}
if return_timestamps and self.type != "ctc":
if return_timestamps and self.type == "seq2seq":
raise ValueError("We cannot return_timestamps yet on non-ctc models !")
if return_timestamps == "char" and self.type == "ctc_with_lm":
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
final_items = []
key = "logits" if self.type == "ctc_with_lm" else "tokens"
@ -335,34 +337,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if self.type == "ctc_with_lm":
if decoder_kwargs is None:
decoder_kwargs = {}
text = self.decoder.decode_beams(items, **decoder_kwargs)[0][0]
beams = self.decoder.decode_beams(items, **decoder_kwargs)
text = beams[0][0]
if return_timestamps:
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
# pre-existing code later
chunk_offset = beams[0][2]
word_offsets = []
for word, (start_offset, end_offset) in chunk_offset:
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
else:
skip_special_tokens = self.type != "ctc"
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
if return_timestamps:
if return_timestamps == "char":
decoded = self.tokenizer.decode(
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
char_offsets = self.tokenizer.decode(
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
)["char_offsets"]
if return_timestamps == "word":
word_offsets = self.tokenizer._get_word_offsets(
char_offsets, self.tokenizer.replace_word_delimiter_char
)
elif return_timestamps == "word":
decoded = self.tokenizer.decode(
items, skip_special_tokens=skip_special_tokens, output_word_offsets=True
)
chunks = []
for item in decoded[f"{return_timestamps}_offsets"]:
start = (
item["start_offset"]
* self.model.config.inputs_to_logits_ratio
/ self.feature_extractor.sampling_rate
)
stop = (
item["end_offset"]
* self.model.config.inputs_to_logits_ratio
/ self.feature_extractor.sampling_rate
)
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
if return_timestamps:
if return_timestamps == "word":
offsets = word_offsets
else:
offsets = char_offsets
chunks = []
for item in offsets:
start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
start /= self.feature_extractor.sampling_rate
stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
stop /= self.feature_extractor.sampling_rate
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
extra = defaultdict(list)
for output in model_outputs:

View File

@ -188,6 +188,32 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
},
)
speech_recognizer.type = "ctc_with_lm"
# Simple test with CTC with LM, chunking + timestamps
output = speech_recognizer(filename, chunk_length_s=2.0, return_timestamps="word")
self.assertEqual(
output,
{
"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajcri",
"chunks": [
{"text": "y", "timestamp": (0.52, 0.54)},
{"text": "en", "timestamp": (0.6, 0.68)},
{"text": "las", "timestamp": (0.74, 0.84)},
{"text": "ramas", "timestamp": (0.94, 1.24)},
{"text": "medio", "timestamp": (1.32, 1.52)},
{"text": "sumergidas", "timestamp": (1.56, 2.22)},
{"text": "revoloteaban", "timestamp": (2.36, 3.0)},
{"text": "algunos", "timestamp": (3.06, 3.38)},
{"text": "pájaros", "timestamp": (3.46, 3.86)},
{"text": "de", "timestamp": (3.92, 4.0)},
{"text": "quimérico", "timestamp": (4.08, 4.6)},
{"text": "y", "timestamp": (4.66, 4.68)},
{"text": "legendario", "timestamp": (4.74, 5.26)},
{"text": "plumajcri", "timestamp": (5.34, 5.74)},
],
},
)
@require_tf
def test_small_model_tf(self):
self.skipTest("Tensorflow not supported yet.")