# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import unittest

import numpy as np
from datasets import load_dataset

from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
    is_librosa_available,
    is_pyctcdecode_available,
    require_datasets,
    require_flax,
    require_librosa,
    require_pyctcdecode,
    require_soundfile,
    slow,
)

from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask


if is_flax_available():
    import jax
    import jax.numpy as jnp
    import optax

    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
    from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
        FlaxWav2Vec2ForCTC,
        FlaxWav2Vec2ForPreTraining,
        FlaxWav2Vec2GumbelVectorQuantizer,
        FlaxWav2Vec2Model,
        _compute_mask_indices,
        _sample_negative_indices,
    )


if is_pyctcdecode_available():
    from transformers import Wav2Vec2ProcessorWithLM


if is_librosa_available():
    import librosa


class FlaxWav2Vec2ModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=1024,  # speech is longer
        is_training=False,
        hidden_size=24,
        feat_extract_norm="layer",
        feat_extract_dropout=0.0,
        feat_extract_activation="gelu",
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        num_hidden_layers=4,
        num_attention_heads=2,
        hidden_dropout_prob=0.1,  # this is most likely not correctly set yet
        intermediate_size=20,
        layer_norm_eps=1e-5,
        hidden_act="gelu",
        initializer_range=0.02,
        vocab_size=32,
        do_stable_layer_norm=True,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_dropout = feat_extract_dropout
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.scope = scope

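        # each convolutional layer of the feature encoder shrinks the sequence (no padding):
        # out_len = (in_len - (kernel - 1)) / stride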
        output_seq_length = self.seq_length
        for kernel, stride in zip(self.conv_kernel, self.conv_stride):
            output_seq_length = (output_seq_length - (kernel - 1)) / stride

        self.output_seq_length = int(math.ceil(output_seq_length))
        self.encoder_seq_length = self.output_seq_length

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = Wav2Vec2Config(
            do_stable_layer_norm=self.do_stable_layer_norm,
            hidden_size=self.hidden_size,
            feat_extract_norm=self.feat_extract_norm,
            feat_extract_dropout=self.feat_extract_dropout,
            feat_extract_activation=self.feat_extract_activation,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            hidden_dropout_prob=self.hidden_dropout_prob,
            intermediate_size=self.intermediate_size,
            layer_norm_eps=self.layer_norm_eps,
            hidden_act=self.hidden_act,
            initializer_range=self.initializer_range,
            vocab_size=self.vocab_size,
        )

        return config, input_values, attention_mask

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_values, attention_mask = config_and_inputs
        inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
        return config, inputs_dict


@require_flax
class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else ()
    )

    def setUp(self):
        self.model_tester = FlaxWav2Vec2ModelTester(self)

    def test_train(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)

        features_shape = (
            input_values.shape[0],
            model._get_feat_extract_output_lengths(np.array(input_values.shape[1])),
        )

        batch_size, sequence_length = features_shape[:2]

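        # SpecAugment-style time masking: spans of 4 frames covering up to half of the sequence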
        mask_prob = 0.5
        mask_length = 4
        mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

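        # separate RNG streams: one for dropout, one for the Gumbel-softmax quantizer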
        dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))

        output = model(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            train=True,
            dropout_rng=dropout_rng,
            gumbel_rng=gumbel_rng,
        )[0]

        self.assertTrue(output.shape == (batch_size, sequence_length, model.config.proj_codevector_dim))

    # overwrite because of `input_values`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_values", "attention_mask"]
            self.assertListEqual(arg_names[:2], expected_arg_names)

    # overwrite because of `input_values`
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_values, attention_mask=None, **kwargs):
                    return model(input_values=input_values, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
            outputs = model(np.ones((1, 1024), dtype="f4"))
            self.assertIsNotNone(outputs)


@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
    def test_compute_mask_indices(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

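        # with mask_length == 1 spans cannot overlap, so every row masks exactly `mask_prob * sequence_length` indices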
        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

    def test_compute_mask_indices_overlap(self):
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        # because spans can overlap, the row sums don't have to add up exactly to
        # `mask_prob * sequence_length`, but they have to be smaller than or equal to it
        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)

    def test_compute_mask_indices_attn_mask_overlap(self):
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
        attention_mask[:2, sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
        )

        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)

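        # no masked span may extend into the padded (attention_mask == 0) part of the input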
        self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)

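    # the quantizer's perplexity is the exponentiated entropy of the time-averaged code probabilities,
    # i.e. a measure of how uniformly the codebook is used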
    def test_compute_perplexity(self):
        probs = np.arange(100).reshape(2, 5, 10) / 100

        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
        self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)

        # mask half of the input
        mask = np.ones((2,), dtype=bool)
        mask[0] = 0

        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
        self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)

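    # contrastive pretraining samples its negatives from other time steps of the same sequence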
    def test_sample_negatives(self):
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
            sequence_length, hidden_size
        )  # each vector consists of a single repeated value
        features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))

        negative_indices = _sample_negative_indices(features.shape, num_negatives)

        features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
        # take negative vectors from sampled indices
        sampled_negatives = features[negative_indices.reshape(-1)]
        negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
            2, 0, 1, 3
        )

        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
        for negative in negatives:
            self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)

        # make sure that full vectors are sampled and not just slices of vectors
        # => this means that `unique()` yields a single value for `hidden_size` dim
        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))

    def test_sample_negatives_with_attn_mask(self):
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
            sequence_length, hidden_size
        )  # each vector consists of a single repeated value

        # second half of last input tensor is padded
        attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8)
        attention_mask[-1, sequence_length // 2 :] = 0

        forbidden_indices = (
            np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length
        ).tolist()

        features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))

        negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask)

        # make sure that no padding tokens are sampled
        self.assertTrue(all(idx not in negative_indices for idx in forbidden_indices))

        features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
        # take negative vectors from sampled indices
        sampled_negatives = features[negative_indices.reshape(-1)]
        negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
            2, 0, 1, 3
        )

        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
        for negative in negatives:
            self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)

        # make sure that full vectors are sampled and not just slices of vectors
        # => this means that `unique()` yields a single value for `hidden_size` dim
        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))


@require_flax
@require_datasets
@require_soundfile
@slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").filter(
            lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)]
        )[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_inference_ctc_robust_batched(self):
        model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

        input_speech = self._load_datasamples(4)

        inputs = processor(input_speech, return_tensors="np", padding=True, truncation=True)

        input_values = inputs.input_values
        attention_mask = inputs.attention_mask

        logits = model(input_values, attention_mask=attention_mask).logits

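        # greedy CTC decoding: take the argmax token per frame; the tokenizer collapses repeated tokens and blanks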
        predicted_ids = jnp.argmax(logits, axis=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
            "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
            "his instant panic was followed by a small sharp blow high on his chest",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_pretrained(self):
        model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-lv60", return_attention_mask=True
        )
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
        )

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            min_masks=2,
        )

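        # the pretraining head returns both the transformer's predictions for the masked frames
        # (`projected_states`) and the quantized targets (`projected_quantized_states`)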
        outputs = model(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # ... now compare to a randomly initialized model

        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
        model_rand = FlaxWav2Vec2ForPreTraining(config)

        outputs_rand = model_rand(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim_rand = optax.cosine_similarity(
            outputs_rand.projected_states, outputs_rand.projected_quantized_states
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

        # a pretrained wav2vec2 model has learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states > 0.5
        # a random wav2vec2 model has not learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
        self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)

    @require_pyctcdecode
    @require_librosa
    def test_wav2vec2_with_lm(self):
        ds = load_dataset("common_voice", "es", split="test", streaming=True)
        sample = next(iter(ds))

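        # Common Voice ships 48 kHz audio; wav2vec2 checkpoints expect 16 kHz input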
        resampled_audio = librosa.resample(sample["audio"]["array"], orig_sr=48_000, target_sr=16_000)

        model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(resampled_audio, return_tensors="np").input_values

        logits = model(input_values).logits

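        # `Wav2Vec2ProcessorWithLM` runs pyctcdecode's beam search with an n-gram language model
        # over the logits instead of plain per-frame argmax decoding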
        transcription = processor.batch_decode(np.array(logits)).text

        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")