Add time stamps for wav2vec2 with lm (#15854)

* [Wav2Vec2 With LM] add timestamps

* correct

* correct

* Apply suggestions from code review

* correct

* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

* make style

* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* make style

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Patrick von Platen · 2022-03-01 17:03:05 +01:00 · commit e064f08150 (parent 3f2e636850)
4 changed files with 215 additions and 16 deletions
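In short: `Wav2Vec2ProcessorWithLM.decode` and `batch_decode` gain an `output_word_offsets` flag, and the returned per-word frame offsets can be converted to timestamps in seconds via the model's downsampling ratio. A minimal usage sketch (checkpoint name taken from the diff below; the silent `raw_audio` is a stand-in for real speech, and pyctcdecode plus kenlm are assumed installed):

```python
import numpy as np
import torch
from transformers import AutoModelForCTC, AutoProcessor

model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

# stand-in input: 1 second of silence at 16 kHz; use a real waveform in practice
raw_audio = np.zeros(16_000, dtype=np.float32)
inputs = processor(raw_audio, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits[0].cpu().numpy()

# the new flag attaches per-word (start_offset, end_offset) in logit frames
output = processor.decode(logits, output_word_offsets=True)

# seconds per logit frame = model stride / audio sampling rate
time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
timestamps = [
    (d["word"], d["start_offset"] * time_offset, d["end_offset"] * time_offset)
    for d in output.word_offsets
]
```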

src/transformers/models/wav2vec2/tokenization_wav2vec2.py

@@ -97,6 +97,8 @@ WAV2VEC2_KWARGS_DOCSTRING = r"""
Whether or not to print more information and warnings.
"""
ListOfDict = List[Dict[str, Union[int, str]]]
@dataclass
class Wav2Vec2CTCTokenizerOutput(ModelOutput):
@@ -106,18 +108,18 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput):
Args:
text (list of `str` or `str`):
Decoded logits in text form. Usually the speech transcription.
char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
offsets can be used to compute time stamps for each character.
word_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
can be used to compute time stamps for each word.
"""
text: Union[List[str], str]
char_offsets: List[Dict[str, Union[float, str]]] = None
word_offsets: List[Dict[str, Union[float, str]]] = None
char_offsets: Union[List[ListOfDict], ListOfDict] = None
word_offsets: Union[List[ListOfDict], ListOfDict] = None
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
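For reference, the `ListOfDict` alias introduced above encodes shapes like the following (values made up for illustration): a single decode yields one list of offset dicts, a batched decode a list of such lists.

```python
# Illustrative only: one `char_offsets` result, per the `ListOfDict` alias above
single = [
    {"char": "h", "start_offset": 0, "end_offset": 1},
    {"char": "i", "start_offset": 1, "end_offset": 2},
]
# a batched call returns a list of such lists, hence Union[List[ListOfDict], ListOfDict]
batched = [single, single]
```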

src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py

@@ -66,6 +66,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
ListOfDict = List[Dict[str, Union[int, str]]]
@dataclass
class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
"""
@@ -74,14 +77,14 @@ class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
Args:
text (list of `str` or `str`):
Decoded logits in text form. Usually the speech transcription.
char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
offsets can be used to compute time stamps for each character.
"""
text: Union[List[str], str]
char_offsets: List[Dict[str, Union[float, str]]] = None
char_offsets: Union[List[ListOfDict], ListOfDict] = None
class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
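The same offset-to-time conversion applies to phoneme offsets. A small sketch with assumed values (a wav2vec2-style stride of 320 input samples per logit frame at 16 kHz, i.e. 20 ms per frame; the offsets are hypothetical):

```python
# assumed values: stride of 320 samples at a 16 kHz sampling rate
time_per_frame = 320 / 16_000  # 0.02 s per logit frame

# hypothetical phoneme offsets, in the shape returned via `char_offsets`
offsets = [{"char": "h", "start_offset": 10, "end_offset": 13}]
durations = [(o["char"], (o["end_offset"] - o["start_offset"]) * time_per_frame) for o in offsets]
# -> [("h", 0.06)]
```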

src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

@@ -19,7 +19,7 @@ import os
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import get_context
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
import numpy as np
@@ -34,23 +34,30 @@ if TYPE_CHECKING:
from ...tokenization_utils import PreTrainedTokenizerBase
ListOfDict = List[Dict[str, Union[int, str]]]
@dataclass
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
"""
Output type of [`Wav2Vec2DecoderWithLM`], with transcription.
Args:
text (list of `str`):
text (list of `str` or `str`):
Decoded logits in text form. Usually the speech transcription.
logit_score (list of `float`):
logit_score (list of `float` or `float`):
Total logit score of the beam associated with produced text.
lm_score (list of `float`):
Fused lm_score of the beam associated with produced text.
word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
can be used to compute time stamps for each word.
"""
text: Union[List[str], str]
logit_score: Union[List[float], float] = None
lm_score: Union[List[float], float] = None
word_offsets: Union[List[ListOfDict], ListOfDict] = None
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
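`Wav2Vec2DecoderWithLMOutput` is a `ModelOutput`, so its fields are readable both as attributes and as dict keys; a quick sketch with made-up values:

```python
from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput

# made-up values, just to show the ModelOutput access patterns
out = Wav2Vec2DecoderWithLMOutput(
    text="hello world",
    logit_score=-1.2,
    lm_score=-0.8,
    word_offsets=[{"word": "hello", "start_offset": 0, "end_offset": 4}],
)
assert out.text == out["text"]  # attribute and key access agree
assert out.word_offsets[0]["word"] == "hello"
```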
@@ -232,6 +239,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
beta: Optional[float] = None,
unk_score_offset: Optional[float] = None,
lm_score_boundary: Optional[bool] = None,
output_word_offsets: bool = False,
):
"""
Batch decode output logits to audio transcription with language model support.
@@ -267,6 +275,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
output_word_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
and model downsampling rate to compute the time-stamps of transcribed words.
<Tip>
Please take a look at the Example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
better understand how to make use of `output_word_offsets`.
[`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
output.
</Tip>
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
@@ -310,13 +330,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
pool.close()
# extract text and scores
batch_texts, logit_scores, lm_scores = [], [], []
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
for d in decoded_beams:
batch_texts.append(d[0][0])
logit_scores.append(d[0][-2])
lm_scores.append(d[0][-1])
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]])
word_offsets = word_offsets if output_word_offsets else None
return Wav2Vec2DecoderWithLMOutput(
text=batch_texts, logit_score=logit_scores, lm_score=lm_scores, word_offsets=word_offsets
)
def decode(
self,
@@ -330,6 +355,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
beta: Optional[float] = None,
unk_score_offset: Optional[float] = None,
lm_score_boundary: Optional[bool] = None,
output_word_offsets: bool = False,
):
"""
Decode output logits to audio transcription with language model support.
@@ -357,11 +383,65 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
output_word_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
and model downsampling rate to compute the time-stamps of transcribed words.
<Tip>
Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
better understand how to make use of `output_word_offsets`.
</Tip>
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
"""
Example:
```python
>>> # Let's see how to retrieve time steps for a model
>>> from transformers import AutoProcessor, AutoModelForCTC
>>> from datasets import load_dataset
>>> import datasets
>>> import torch
>>> # import model and processor (feature extractor, tokenizer and decoder)
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
>>> # load first sample of English common_voice
>>> dataset = load_dataset("common_voice", "en", split="train", streaming=True)
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>>> dataset_iter = iter(dataset)
>>> sample = next(dataset_iter)
>>> # forward sample through model to get logits
>>> input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
>>> with torch.no_grad():
... logits = model(input_values).logits[0].cpu().numpy()
>>> # retrieve word offsets from the decoded output
>>> outputs = processor.decode(logits, output_word_offsets=True)
>>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
>>> time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
>>> word_offsets = [
... {
... "word": d["word"],
... "start_time": d["start_offset"] * time_offset,
... "end_time": d["end_offset"] * time_offset,
... }
... for d in outputs.word_offsets
... ]
>>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
>>> # https://huggingface.co/datasets/common_voice/viewer/en/train
>>> word_offsets
>>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
>>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A',
>>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ...
```"""
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
@@ -390,9 +470,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
hotword_weight=hotword_weight,
)
word_offsets = None
if output_word_offsets:
word_offsets = [
{"word": word, "start_offset": start_offset, "end_offset": end_offset}
for word, (start_offset, end_offset) in decoded_beams[0][2]
]
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(
text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
text=decoded_beams[0][0],
logit_score=decoded_beams[0][-2],
lm_score=decoded_beams[0][-1],
word_offsets=word_offsets,
)
@contextmanager
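The indexing above (`d[0][0]`, `d[0][1]`, `decoded_beams[0][2]`, `[-2]`, `[-1]`) relies on pyctcdecode's beam tuple layouts, which differ between the single and the multiprocessing-safe batched path (worth pinning the pyctcdecode version); a self-contained mock of the unpacking:

```python
# pyctcdecode beam layouts as relied on above (as of the version this PR targets):
#   decode_beams        -> (text, last_lm_state, text_frames, logit_score, lm_score)
#   decode_beams_batch  -> (text, text_frames, logit_score, lm_score)  # lm_state dropped for pickling
# `text_frames` is a list of (word, (start_frame, end_frame)) pairs.
beam = ("hello world", None, [("hello", (0, 4)), ("world", (6, 10))], -1.2, -0.8)  # mock beam
word_offsets = [{"word": w, "start_offset": s, "end_offset": e} for w, (s, e) in beam[2]]
assert (beam[-2], beam[-1]) == (-1.2, -0.8)  # logit_score, lm_score via negative indexing
```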

tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py

@@ -20,13 +20,15 @@ import unittest
from multiprocessing import get_context
from pathlib import Path
import datasets
import numpy as np
from datasets import load_dataset
from transformers import AutoProcessor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_pyctcdecode
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
@@ -35,6 +37,10 @@ if is_pyctcdecode_available():
from huggingface_hub import snapshot_download
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput
if is_torch_available():
from transformers import Wav2Vec2ForCTC
@require_pyctcdecode
@@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_auto = processor_auto.batch_decode(logits)
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
@staticmethod
def get_from_offsets(offsets, key):
retrieved_list = [d[key] for d in offsets]
return retrieved_list
def test_offsets_integration_fast(self):
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
logits = self._get_dummy_logits()[0]
outputs = processor.decode(logits, output_word_offsets=True)
# check Wav2Vec2DecoderWithLMOutput keys (text, logit_score, lm_score, word_offsets)
self.assertEqual(len(outputs.keys()), 4)
self.assertTrue("text" in outputs)
self.assertTrue("word_offsets" in outputs)
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["<s>", "<s>", "</s>"])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 2, 4])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [1, 3, 5])
def test_offsets_integration_fast_batch(self):
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
logits = self._get_dummy_logits()
outputs = processor.batch_decode(logits, output_word_offsets=True)
# check Wav2Vec2DecoderWithLMOutput keys (text, logit_score, lm_score, word_offsets)
self.assertEqual(len(outputs.keys()), 4)
self.assertTrue("text" in outputs)
self.assertTrue("word_offsets" in outputs)
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
self.assertListEqual(
[" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text
)
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "word"), ["<s>", "<s>", "</s>"])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "start_offset"), [0, 2, 4])
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "end_offset"), [1, 3, 5])
@slow
@require_torch
@require_torchaudio
def test_word_time_stamp_integration(self):
import torch
ds = load_dataset("common_voice", "en", split="train", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
ds_iter = iter(ds)
sample = next(ds_iter)
processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
# compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values).logits.cpu().numpy()
output = processor.decode(logits[0], output_word_offsets=True)
time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
word_time_stamps = [
{
"start_time": d["start_offset"] * time_offset,
"end_time": d["end_offset"] * time_offset,
"word": d["word"],
}
for d in output["word_offsets"]
]
EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
# output words
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
# output times
start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")]
end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")]
# fmt: off
self.assertListEqual(
start_times,
[
1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
],
)
self.assertListEqual(
end_times,
[
1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
],
)
# fmt: on
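A sanity check on the expected values: wav2vec2-base has `inputs_to_logits_ratio = 320` (seven convolutions with strides 5·2·2·2·2·2·2), so at 16 kHz each logit frame spans 0.02 s, and the first expected start time of 1.42 s corresponds to a start offset of 71 frames:

```python
# sanity arithmetic for the expected timestamps above
inputs_to_logits_ratio = 320  # conv strides 5*2*2*2*2*2*2 for wav2vec2-base
sampling_rate = 16_000
time_offset = inputs_to_logits_ratio / sampling_rate  # 0.02 s per logit frame
assert round(71 * time_offset, 2) == 1.42  # 71 frames -> first expected start time
```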