mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-06 14:20:04 +06:00

* fix_torch_device_generate_test * remove @ * start adding tests * correct wav2vec2 pretraining * up * up Co-authored-by: Patrick von Platen <patrick@huggingface.co>
417 lines
16 KiB
Python
417 lines
16 KiB
Python
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import inspect
|
|
import math
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from transformers import Wav2Vec2Config, is_flax_available
|
|
from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow
|
|
|
|
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
|
|
|
|
|
|
if is_flax_available():
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import optax
|
|
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
|
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
|
|
FlaxWav2Vec2ForCTC,
|
|
FlaxWav2Vec2ForPreTraining,
|
|
FlaxWav2Vec2GumbelVectorQuantizer,
|
|
FlaxWav2Vec2Model,
|
|
_compute_mask_indices,
|
|
_sample_negative_indices,
|
|
)
|
|
|
|
|
|
class FlaxWav2Vec2ModelTester:
    """Holds the hyper-parameters of a tiny Wav2Vec2 model and produces a config plus inputs for tests.

    Every constructor argument is stored unchanged as an attribute of the same name. In addition,
    `output_seq_length` / `encoder_seq_length` are derived from `seq_length` and the conv kernels/strides.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=1024,  # speech is longer
        is_training=False,
        hidden_size=24,
        feat_extract_norm="layer",
        feat_extract_dropout=0.0,
        feat_extract_activation="gelu",
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        num_hidden_layers=4,
        num_attention_heads=2,
        hidden_dropout_prob=0.1,  # this is most likely not correctly set yet
        intermediate_size=20,
        layer_norm_eps=1e-5,
        hidden_act="gelu",
        initializer_range=0.02,
        vocab_size=32,
        do_stable_layer_norm=True,
        scope=None,
    ):
        # test harness bookkeeping
        self.parent = parent
        self.scope = scope
        self.is_training = is_training

        # input dimensions
        self.batch_size = batch_size
        self.seq_length = seq_length

        # convolutional feature extractor
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_dropout = feat_extract_dropout
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias

        # convolutional positional embeddings
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups

        # transformer encoder
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.do_stable_layer_norm = do_stable_layer_norm

        # output head
        self.vocab_size = vocab_size

        # shrink the raw length once per conv layer, then round up to a whole number of frames
        remaining = self.seq_length
        for kernel_width, step in zip(self.conv_kernel, self.conv_stride):
            remaining = (remaining - (kernel_width - 1)) / step
        self.output_seq_length = int(math.ceil(remaining))
        self.encoder_seq_length = self.output_seq_length

    def prepare_config_and_inputs(self):
        """Return `(config, input_values, attention_mask)` built from the stored hyper-parameters."""
        input_shape = [self.batch_size, self.seq_length]
        input_values = floats_tensor(input_shape, self.vocab_size)
        attention_mask = random_attention_mask(input_shape)

        # names of the attributes that are forwarded one-to-one to `Wav2Vec2Config`
        config_fields = (
            "do_stable_layer_norm",
            "hidden_size",
            "feat_extract_norm",
            "feat_extract_dropout",
            "feat_extract_activation",
            "conv_dim",
            "conv_stride",
            "conv_kernel",
            "conv_bias",
            "num_conv_pos_embeddings",
            "num_conv_pos_embedding_groups",
            "num_hidden_layers",
            "num_attention_heads",
            "hidden_dropout_prob",
            "intermediate_size",
            "layer_norm_eps",
            "hidden_act",
            "initializer_range",
            "vocab_size",
        )
        config = Wav2Vec2Config(**{field: getattr(self, field) for field in config_fields})

        return config, input_values, attention_mask

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` with the inputs keyed by `input_values` and `attention_mask`."""
        config, input_values, attention_mask = self.prepare_config_and_inputs()
        return config, {"input_values": input_values, "attention_mask": attention_mask}
|
|
|
|
|
|
@require_flax
class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
    """Model tests for the Flax Wav2Vec2 classes; several common tests are overridden for `input_values`."""

    all_model_classes = (
        (FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else ()
    )

    def setUp(self):
        self.model_tester = FlaxWav2Vec2ModelTester(self)

    def test_train(self):
        """Call the pre-training model with `train=True` and a time mask, then check the first output's shape."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_values, attention_mask = inputs_dict["input_values"], inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)

        # the mask is defined over the feature extractor's output frames, not over raw samples
        batch_size = input_values.shape[0]
        sequence_length = model._get_feat_extract_output_lengths(np.array(input_values.shape[1]))

        mask_prob, mask_length = 0.5, 4
        mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))

        model_outputs = model(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            train=True,
            dropout_rng=dropout_rng,
            gumbel_rng=gumbel_rng,
        )
        output = model_outputs[0]

        expected_shape = (batch_size, sequence_length, model.config.proj_codevector_dim)
        self.assertTrue(output.shape == expected_shape)

    # overwrite because of `input_values`
    def test_forward_signature(self):
        """Every model's `__call__` must take `input_values` and `attention_mask` as its first two arguments."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        expected_arg_names = ["input_values", "attention_mask"]
        for model_class in self.all_model_classes:
            call_signature = inspect.signature(model_class(config).__call__)
            # `parameters` is an ordered mapping, so the order of the names is deterministic
            leading_arg_names = list(call_signature.parameters.keys())[:2]
            self.assertListEqual(leading_arg_names, expected_arg_names)

    # overwrite because of `input_values`
    @slow
    def test_jit_compilation(self):
        """Outputs with JIT enabled and disabled must agree in count and in per-output shape."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                def forward(input_values, attention_mask=None, **kwargs):
                    return model(input_values=input_values, attention_mask=attention_mask, **kwargs)

                model_jitted = jax.jit(forward)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"), jax.disable_jit():
                    outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    @slow
    def test_model_from_pretrained(self):
        """Each model class can load the converted PyTorch checkpoint and run on a dummy waveform."""
        dummy_waveform = np.ones((1, 1024), dtype="f4")
        for model_class in self.all_model_classes:
            model = model_class.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
            self.assertIsNotNone(model(dummy_waveform))
|
|
|
|
|
|
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
    """Tests for the Flax Wav2Vec2 helpers: mask computation, codebook perplexity and negative sampling."""

    def test_compute_mask_indices(self):
        """With `mask_length=1` every row must mask exactly `mask_prob * sequence_length` positions."""
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

    def test_compute_mask_indices_overlap(self):
        """With `mask_length > 1` spans may overlap, so each row masks at most `mask_prob * sequence_length`."""
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
        for batch_sum in mask.sum(axis=-1):
            self.assertLessEqual(int(batch_sum), mask_prob * sequence_length)

    def test_compute_mask_indices_attn_mask_overlap(self):
        """Positions whose attention mask is zero must never be selected for masking."""
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        # zero out the attention mask over the second half of the first two rows
        attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
        attention_mask[:2, sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
        )

        for batch_sum in mask.sum(axis=-1):
            self.assertLessEqual(int(batch_sum), mask_prob * sequence_length)

        self.assertEqual(mask[:2, sequence_length // 2 :].sum(), 0)

    def test_compute_perplexity(self):
        """Check `_compute_perplexity` against reference values, with and without a mask."""
        probs = np.arange(100).reshape(2, 5, 10) / 100

        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
        self.assertLess(abs(ppl.item() - 141.4291), 1e-3)

        # mask half of the input
        # NOTE: builtin `bool` is used because the `np.bool` alias was removed from NumPy (1.24+)
        mask = np.ones((2,), dtype=bool)
        mask[0] = 0

        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
        self.assertLess(abs(ppl.item() - 58.6757), 1e-3)

    def test_sample_negatives(self):
        """Sampled negatives must be whole feature vectors and must differ from the positive at each position."""
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
            sequence_length, hidden_size
        )  # each value in vector consists of same value
        features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))

        negative_indices = _sample_negative_indices(features.shape, num_negatives)

        features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
        # take negative vectors from sampled indices
        sampled_negatives = features[negative_indices.reshape(-1)]
        negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
            2, 0, 1, 3
        )

        self.assertEqual(negatives.shape, (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
        for negative in negatives:
            self.assertEqual(((negative - features.reshape(negative.shape)) == 0).sum(), 0)

        # make sure that full vectors are sampled and not values of vectors
        # => this means that `unique()` yields a single value for `hidden_size` dim
        # (`assertEqual`, not `assertTrue`: the latter would treat the expected shape as a failure message)
        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
|
|
|
|
|
@require_flax
@require_datasets
@require_soundfile
@slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
    """Slow integration tests that run pretrained Wav2Vec2 checkpoints on dummy LibriSpeech samples."""

    def _load_datasamples(self, num_samples):
        """Return the decoded speech arrays of the first `num_samples` utterances of the dummy dataset."""
        from datasets import load_dataset

        import soundfile as sf

        # zero-pad to four digits so that ids stay well-formed for `num_samples > 10` as well
        ids = [f"1272-141231-{i:04d}" for i in range(num_samples)]

        # map files to raw
        def map_to_array(batch):
            speech, _ = sf.read(batch["file"])
            batch["speech"] = speech
            return batch

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

        ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)

        return ds["speech"][:num_samples]

    def test_inference_ctc_robust_batched(self):
        """The CTC model must transcribe a padded batch of four samples exactly."""
        model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

        input_speech = self._load_datasamples(4)

        # request NumPy arrays: the Flax model should not be fed PyTorch tensors
        inputs = processor(input_speech, return_tensors="np", padding=True, truncation=True)

        input_values = inputs.input_values
        attention_mask = inputs.attention_mask

        logits = model(input_values, attention_mask=attention_mask).logits

        predicted_ids = jnp.argmax(logits, axis=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
            "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
            "his instant panic was followed by a small sharp blow high on his chest",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_pretrained(self):
        """A pretrained model must align projected and quantized states far better than a random one."""
        model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-lv60", return_attention_mask=True
        )
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

        # the time mask is defined over the feature extractor's output frames, not over raw samples
        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
        )

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            min_masks=2,
        )

        outputs = model(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # ... now compare to randomly initialized model

        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
        model_rand = FlaxWav2Vec2ForPreTraining(config)

        outputs_rand = model_rand(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity (same `epsilon` as above so both models are scored identically)
        cosine_sim_rand = optax.cosine_similarity(
            outputs_rand.projected_states, outputs_rand.projected_quantized_states, epsilon=1e-8
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

        # a pretrained wav2vec2 model has learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states > 0.5
        # a random wav2vec2 model has not learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
        # the check below requires the pretrained mean to exceed five times the random mean
        self.assertGreater(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item(), 0)
|