mirror of https://github.com/huggingface/transformers.git
Add time stamps for wav2vec2 with lm (#15854)
* [Wav2Vec2 With LM] add timestamps
* correct
* correct
* Apply suggestions from code review
* correct
* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
* make style
* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
* make style
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
3f2e636850
commit
e064f08150
@@ -97,6 +97,8 @@ WAV2VEC2_KWARGS_DOCSTRING = r"""
             Whether or not to print more information and warnings.
 """
 
+ListOfDict = List[Dict[str, Union[int, str]]]
+
 
 @dataclass
 class Wav2Vec2CTCTokenizerOutput(ModelOutput):
@@ -106,18 +108,18 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput):
     Args:
         text (list of `str` or `str`):
             Decoded logits in text form. Usually the speech transcription.
-        char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
+        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
             Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
             offsets can be used to compute time stamps for each character.
-        word_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
+        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
             Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
             can be used to compute time stamps for each word.
     """
 
     text: Union[List[str], str]
-    char_offsets: List[Dict[str, Union[float, str]]] = None
-    word_offsets: List[Dict[str, Union[float, str]]] = None
+    char_offsets: Union[List[ListOfDict], ListOfDict] = None
+    word_offsets: Union[List[ListOfDict], ListOfDict] = None
 
 
 class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
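A minimal, self-contained sketch of the offset-to-timestamp conversion this docstring describes (the offset values and the 320x downsampling factor below are illustrative assumptions, not part of this diff):

```python
# Made-up word offsets in logit-frame units, shaped like the tokenizer output.
word_offsets = [
    {"word": "HELLO", "start_offset": 12, "end_offset": 21},
    {"word": "WORLD", "start_offset": 25, "end_offset": 33},
]

# Assumption: wav2vec2's usual downsampling factor of 320 on 16 kHz audio,
# i.e. one logit frame every 320 / 16_000 = 0.02 seconds.
time_offset = 320 / 16_000

word_times = [
    {
        "word": d["word"],
        "start_time": round(d["start_offset"] * time_offset, 2),
        "end_time": round(d["end_offset"] * time_offset, 2),
    }
    for d in word_offsets
]
print(word_times)
# [{'word': 'HELLO', 'start_time': 0.24, 'end_time': 0.42},
#  {'word': 'WORLD', 'start_time': 0.5, 'end_time': 0.66}]
```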
@@ -66,6 +66,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
 
 
+ListOfDict = List[Dict[str, Union[int, str]]]
+
+
 @dataclass
 class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
     """
@@ -74,14 +77,14 @@ class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
     Args:
         text (list of `str` or `str`):
             Decoded logits in text form. Usually the speech transcription.
-        char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
+        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
             Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
             offsets can be used to compute time stamps for each character.
     """
 
     text: Union[List[str], str]
-    char_offsets: List[Dict[str, Union[float, str]]] = None
+    char_offsets: Union[List[ListOfDict], ListOfDict] = None
 
 
 class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
@@ -19,7 +19,7 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from multiprocessing import get_context
-from typing import TYPE_CHECKING, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 
@@ -34,23 +34,30 @@ if TYPE_CHECKING:
     from ...tokenization_utils import PreTrainedTokenizerBase
 
 
+ListOfDict = List[Dict[str, Union[int, str]]]
+
+
 @dataclass
 class Wav2Vec2DecoderWithLMOutput(ModelOutput):
     """
     Output type of [`Wav2Vec2DecoderWithLM`], with transcription.
 
     Args:
-        text (list of `str`):
+        text (list of `str` or `str`):
             Decoded logits in text form. Usually the speech transcription.
-        logit_score (list of `float`):
+        logit_score (list of `float` or `float`):
             Total logit score of the beam associated with produced text.
         lm_score (list of `float`):
             Fused lm_score of the beam associated with produced text.
+        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
+            Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
+            can be used to compute time stamps for each word.
     """
 
     text: Union[List[str], str]
     logit_score: Union[List[float], float] = None
     lm_score: Union[List[float], float] = None
+    word_offsets: Union[List[ListOfDict], ListOfDict] = None
 
 
 class Wav2Vec2ProcessorWithLM(ProcessorMixin):
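To illustrate the `Union[List[ListOfDict], ListOfDict]` typing above: `decode` fills each field for a single sample, while `batch_decode` fills it with one entry per sample. A hedged sketch with invented values:

```python
from typing import Dict, List, Union

ListOfDict = List[Dict[str, Union[int, str]]]

# decode(): word_offsets for one audio sample -- a plain ListOfDict
single: ListOfDict = [{"word": "WHY", "start_offset": 71, "end_offset": 77}]

# batch_decode(): one ListOfDict per audio sample in the batch
batched: List[ListOfDict] = [
    [{"word": "WHY", "start_offset": 71, "end_offset": 77}],
    [{"word": "DOES", "start_offset": 82, "end_offset": 94}],
]
```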
@@ -232,6 +239,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         beta: Optional[float] = None,
         unk_score_offset: Optional[float] = None,
         lm_score_boundary: Optional[bool] = None,
+        output_word_offsets: bool = False,
     ):
         """
         Batch decode output logits to audio transcription with language model support.
@@ -267,6 +275,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
                 Amount of log score offset for unknown tokens
             lm_score_boundary (`bool`, *optional*):
                 Whether to have kenlm respect boundaries when scoring
+            output_word_offsets (`bool`, *optional*, defaults to `False`):
+                Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
+                and model downsampling rate to compute the time-stamps of transcribed words.
+
+                <Tip>
+
+                Please take a look at the Example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
+                better understand how to make use of `output_word_offsets`.
+                [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
+                output.
+
+                </Tip>
 
         Returns:
             [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
@@ -310,13 +330,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         pool.close()
 
         # extract text and scores
-        batch_texts, logit_scores, lm_scores = [], [], []
+        batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
         for d in decoded_beams:
             batch_texts.append(d[0][0])
             logit_scores.append(d[0][-2])
             lm_scores.append(d[0][-1])
-        # more output features will be added in the future
-        return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
+            word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]])
+
+        word_offsets = word_offsets if output_word_offsets else None
+
+        return Wav2Vec2DecoderWithLMOutput(
+            text=batch_texts, logit_score=logit_scores, lm_score=lm_scores, word_offsets=word_offsets
+        )
 
     def decode(
         self,
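The `d[0][...]` indexing above leans on pyctcdecode's beam layout; in the multiprocessing-safe batch output a beam is, to the best of my reading, a `(text, word_frames, logit_score, lm_score)` tuple. A sketch with an invented beam under that assumption:

```python
# Invented top beam for one sample, shaped like pyctcdecode's batch output:
# (text, word_frames, logit_score, lm_score)
top_beam = ("HELLO WORLD", [("HELLO", (0, 4)), ("WORLD", (6, 10))], -3.2, -1.5)

text = top_beam[0]          # "HELLO WORLD"
logit_score = top_beam[-2]  # -3.2
lm_score = top_beam[-1]     # -1.5

# Same transformation as the diff's list comprehension over d[0][1]:
word_offsets = [
    {"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in top_beam[1]
]
print(word_offsets)
# [{'word': 'HELLO', 'start_offset': 0, 'end_offset': 4},
#  {'word': 'WORLD', 'start_offset': 6, 'end_offset': 10}]
```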
@@ -330,6 +355,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         beta: Optional[float] = None,
         unk_score_offset: Optional[float] = None,
         lm_score_boundary: Optional[bool] = None,
+        output_word_offsets: bool = False,
     ):
         """
         Decode output logits to audio transcription with language model support.
@@ -357,11 +383,65 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
                 Amount of log score offset for unknown tokens
             lm_score_boundary (`bool`, *optional*):
                 Whether to have kenlm respect boundaries when scoring
+            output_word_offsets (`bool`, *optional*, defaults to `False`):
+                Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
+                and model downsampling rate to compute the time-stamps of transcribed words.
+
+                <Tip>
+
+                Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
+                better understand how to make use of `output_word_offsets`.
+
+                </Tip>
 
         Returns:
             [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
 
-        """
+        Example:
+
+        ```python
+        >>> # Let's see how to retrieve time steps for a model
+        >>> from transformers import AutoModelForCTC, AutoProcessor
+        >>> from datasets import load_dataset
+        >>> import datasets
+        >>> import torch

+        >>> # import model and processor (feature extractor, tokenizer and decoder)
+        >>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
+        >>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

+        >>> # load first sample of English common_voice
+        >>> dataset = load_dataset("common_voice", "en", split="train", streaming=True)
+        >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+        >>> dataset_iter = iter(dataset)
+        >>> sample = next(dataset_iter)

+        >>> # forward sample through model to get greedily predicted transcription ids
+        >>> input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
+        >>> with torch.no_grad():
+        ...     logits = model(input_values).logits[0].cpu().numpy()

+        >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
+        >>> outputs = processor.decode(logits, output_word_offsets=True)
+        >>> # compute `time_offset` in seconds as ratio of model downsampling and sampling_rate
+        >>> time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate

+        >>> word_offsets = [
+        ...     {
+        ...         "word": d["word"],
+        ...         "start_time": d["start_offset"] * time_offset,
+        ...         "end_time": d["end_offset"] * time_offset,
+        ...     }
+        ...     for d in outputs.word_offsets
+        ... ]
+        >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
+        >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
+        >>> word_offsets
+        >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
+        >>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A',
+        >>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ...
+        ```"""
 
         from pyctcdecode.constants import (
             DEFAULT_BEAM_WIDTH,
             DEFAULT_HOTWORD_WEIGHT,
@@ -390,9 +470,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
             hotword_weight=hotword_weight,
         )
 
+        word_offsets = None
+        if output_word_offsets:
+            word_offsets = [
+                {"word": word, "start_offset": start_offset, "end_offset": end_offset}
+                for word, (start_offset, end_offset) in decoded_beams[0][2]
+            ]
+
         # more output features will be added in the future
         return Wav2Vec2DecoderWithLMOutput(
-            text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
+            text=decoded_beams[0][0],
+            logit_score=decoded_beams[0][-2],
+            lm_score=decoded_beams[0][-1],
+            word_offsets=word_offsets,
         )
 
     @contextmanager
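Note that `decode` reads the frame spans from `decoded_beams[0][2]` while `batch_decode` used index 1; presumably the non-batched pyctcdecode beam carries the kenlm language-model state at index 1. A sketch under that assumption (the beam values are invented):

```python
# Invented non-batched beam: (text, last_lm_state, word_frames, logit_score, lm_score)
top_beam = ("WHY DOES", None, [("WHY", (71, 77)), ("DOES", (82, 94))], -5.1, -2.7)

word_offsets = [
    {"word": word, "start_offset": start_offset, "end_offset": end_offset}
    for word, (start_offset, end_offset) in top_beam[2]
]
print(word_offsets)
# [{'word': 'WHY', 'start_offset': 71, 'end_offset': 77},
#  {'word': 'DOES', 'start_offset': 82, 'end_offset': 94}]
```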
@@ -20,13 +20,15 @@ import unittest
 from multiprocessing import get_context
 from pathlib import Path
 
+import datasets
 import numpy as np
+from datasets import load_dataset
 
 from transformers import AutoProcessor
-from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available
 from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
 from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import require_pyctcdecode
+from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
 
 from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
@@ -35,6 +37,10 @@ if is_pyctcdecode_available():
     from huggingface_hub import snapshot_download
     from pyctcdecode import BeamSearchDecoderCTC
     from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
+    from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput
+
+if is_torch_available():
+    from transformers import Wav2Vec2ForCTC
 
 
 @require_pyctcdecode
@@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
         decoded_auto = processor_auto.batch_decode(logits)
 
         self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
+
+    @staticmethod
+    def get_from_offsets(offsets, key):
+        retrieved_list = [d[key] for d in offsets]
+        return retrieved_list
+
+    def test_offsets_integration_fast(self):
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
+        logits = self._get_dummy_logits()[0]
+
+        outputs = processor.decode(logits, output_word_offsets=True)
+        # check Wav2Vec2DecoderWithLMOutput keys for word
+        self.assertEqual(len(outputs.keys()), 4)
+        self.assertTrue("text" in outputs)
+        self.assertTrue("word_offsets" in outputs)
+        self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
+
+        self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["<s>", "<s>", "</s>"])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 2, 4])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [1, 3, 5])
+
+    def test_offsets_integration_fast_batch(self):
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
+        logits = self._get_dummy_logits()
+
+        outputs = processor.batch_decode(logits, output_word_offsets=True)
+
+        # check Wav2Vec2DecoderWithLMOutput keys for word
+        self.assertEqual(len(outputs.keys()), 4)
+        self.assertTrue("text" in outputs)
+        self.assertTrue("word_offsets" in outputs)
+        self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
+
+        self.assertListEqual(
+            [" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text
+        )
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "word"), ["<s>", "<s>", "</s>"])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "start_offset"), [0, 2, 4])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "end_offset"), [1, 3, 5])
+
+    @slow
+    @require_torch
+    @require_torchaudio
+    def test_word_time_stamp_integration(self):
+        import torch
+
+        ds = load_dataset("common_voice", "en", split="train", streaming=True)
+        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+        ds_iter = iter(ds)
+        sample = next(ds_iter)
+
+        processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
+        model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
+
+        # compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
+        input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
+
+        with torch.no_grad():
+            logits = model(input_values).logits.cpu().numpy()
+
+        output = processor.decode(logits[0], output_word_offsets=True)
+
+        time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
+        word_time_stamps = [
+            {
+                "start_time": d["start_offset"] * time_offset,
+                "end_time": d["end_offset"] * time_offset,
+                "word": d["word"],
+            }
+            for d in output["word_offsets"]
+        ]
+
+        EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
+
+        # output words
+        self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
+        self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
+
+        # output times
+        start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")]
+        end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")]
+
+        # fmt: off
+        self.assertListEqual(
+            start_times,
+            [
+                1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
+            ],
+        )
+
+        self.assertListEqual(
+            end_times,
+            [
+                1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
+            ],
+        )
+        # fmt: on