mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-05 22:00:09 +06:00

* Fix TypeError: Object of type int64 is not JSON serializable * Convert numpy.float64 and numpy.int64 to float and int for JSON serialization * Black-reformatted examples/pytorch/token-classification/run_ner_no_trainer.py * make style
652 lines
26 KiB
Python
652 lines
26 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Testing suite for the PyTorch MCTCT model. """
|
|
|
|
import inspect
|
|
import math
|
|
import unittest
|
|
|
|
from datasets import load_dataset
|
|
|
|
from transformers import MCTCTConfig, is_torch_available
|
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
|
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import MCTCTForCTC, MCTCTModel, MCTCTProcessor
|
|
|
|
|
|
class MCTCTModelTester:
    """Builds a small random `MCTCTConfig` plus matching inputs and runs model checks.

    `parent` is the `unittest.TestCase` that owns this tester; every assertion is routed through it.
    """

    def __init__(
        self,
        parent,
        batch_size=10,
        seq_length=40,  # speech is longer
        is_training=False,
        vocab_size=32,
        hidden_size=128 * 4,
        num_hidden_layers=4,
        intermediate_size=20,
        num_attention_heads=4,
        attention_head_dim=128,
        max_position_embeddings=920,
        layer_norm_eps=1e-5,
        layerdrop=0.3,
        hidden_act="relu",
        initializer_range=0.02,
        hidden_dropout_prob=0.3,
        attention_probs_dropout_prob=0.3,
        conv_glu_dim=1,
        conv_dropout=0.3,
        num_conv_layers=1,
        conv_kernel=(7,),
        conv_stride=(3,),
        input_feat_per_channel=80,
        input_channels=1,
        conv_channels=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length  # speech is longer
        self.is_training = is_training

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads

        self.attention_head_dim = attention_head_dim
        self.max_position_embeddings = max_position_embeddings

        self.layer_norm_eps = layer_norm_eps
        self.layerdrop = layerdrop
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        self.conv_glu_dim = conv_glu_dim
        self.conv_dropout = conv_dropout
        self.num_conv_layers = num_conv_layers
        self.conv_kernel = conv_kernel
        self.conv_stride = conv_stride
        self.input_feat_per_channel = input_feat_per_channel
        self.input_channels = input_channels
        self.conv_channels = conv_channels

        # Expected sequence length after the convolutional subsampler: the Conv1d output-length formula
        # with padding = kernel // 2 and dilation 1, applied once per conv layer.
        output_seq_length = self.seq_length
        dilation = 1
        for _, kernel_sz, stride in zip(range(self.num_conv_layers), self.conv_kernel, self.conv_stride):
            padding = kernel_sz // 2
            output_seq_length = output_seq_length + 2 * padding - dilation * (kernel_sz - 1) - 1
            output_seq_length = torch.div(output_seq_length, stride, rounding_mode="trunc") + 1

        self.output_seq_length = int(math.ceil(output_seq_length))
        self.encoder_seq_length = self.output_seq_length

    def prepare_config_and_inputs(self):
        """Return ``(config, input_features, attention_mask)``.

        `input_features` has shape (batch_size, seq_length, input_feat_per_channel); the attention mask is
        all ones with shape (batch_size, seq_length).
        """
        input_features = floats_tensor(
            [self.batch_size, self.seq_length, self.input_feat_per_channel], self.vocab_size
        )
        attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device)

        config = self.get_config()

        return config, input_features, attention_mask

    def get_config(self):
        """Build an `MCTCTConfig` from the tester's hyper-parameters."""
        return MCTCTConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=self.intermediate_size,
            num_attention_heads=self.num_attention_heads,
            attention_head_dim=self.attention_head_dim,
            max_position_embeddings=self.max_position_embeddings,
            layer_norm_eps=self.layer_norm_eps,
            layerdrop=self.layerdrop,
            hidden_act=self.hidden_act,
            initializer_range=self.initializer_range,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            conv_glu_dim=self.conv_glu_dim,
            conv_dropout=self.conv_dropout,
            num_conv_layers=self.num_conv_layers,
            conv_kernel=self.conv_kernel,
            conv_stride=self.conv_stride,
            input_feat_per_channel=self.input_feat_per_channel,
            input_channels=self.input_channels,
            conv_channels=self.conv_channels,
        )

    def create_and_check_model(self, config, input_features, attention_mask):
        """Check the base model's last hidden state is (batch, subsampled length, hidden_size)."""
        model = MCTCTModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_features, attention_mask=attention_mask)

        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
        )

    def create_and_check_model_for_ctc(self, config, input_features, attention_mask):
        """Check the CTC model's logits are (batch, subsampled length, vocab_size)."""
        model = MCTCTForCTC(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_features, attention_mask=attention_mask)
        # This tester defines no adapter length, so compare against the encoder's subsampled length.
        # NOTE(review): assumes MCTCTForCTC applies no adapter/extra downsampling after the encoder — confirm.
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size)
        )

    def create_and_check_batch_inference(self, config, input_features, *args):
        """Check padded batched inference matches running each example on its own.

        The first three examples are kept; the first two are right-padded to half their length. For each
        example, the batched output over its unpadded positions must match the single-example output.
        """
        # test does not pass for models making use of `group_norm`
        # check: https://github.com/pytorch/fairseq/issues/3227
        model = MCTCTModel(config=config)
        model.to(torch_device)
        model.eval()

        input_features = input_features[:3]
        attention_mask = torch.ones(input_features.shape[:-1], device=torch_device, dtype=torch.bool)

        # input_features is (batch, time, features): pad along time (dim -2). Using the feature dim here
        # would yield lengths beyond the time axis, so nothing would actually be padded.
        input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]

        # pad input
        for i, length in enumerate(input_lengths):
            input_features[i, length:] = 0.0
            attention_mask[i, length:] = False

        batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state

        for i, length in enumerate(input_lengths):
            input_slice = input_features[i : i + 1, :length]
            output = model(input_slice).last_hidden_state

            batch_output = batch_outputs[i : i + 1, : output.shape[1]]
            self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))

    def check_ctc_loss(self, config, input_features, *args):
        """Check the CTC loss is a Python float with both "sum" and "mean" reduction on a padded batch."""
        model = MCTCTForCTC(config=config)
        model.to(torch_device)

        # make sure that dropout is disabled
        model.eval()

        input_features = input_features[:3]

        # input_features is a 2D window for each sequence
        attention_mask = torch.ones(input_features.shape[:-1], device=torch_device, dtype=torch.long)

        # -2 since input_features is a 2D window for each sequence in batch
        input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]
        max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
        labels = ids_tensor((input_features.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
        # pad input
        for i in range(len(input_lengths)):
            input_features[i, input_lengths[i] :] = 0.0
            attention_mask[i, input_lengths[i] :] = 0

        model.config.ctc_loss_reduction = "sum"
        sum_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()

        model.config.ctc_loss_reduction = "mean"
        mean_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()

        self.parent.assertTrue(isinstance(sum_loss, float))
        self.parent.assertTrue(isinstance(mean_loss, float))

    def check_ctc_training(self, config, input_features, *args):
        """Run a train-mode forward/backward with padded inputs; the loss must not be infinite."""
        config.ctc_zero_infinity = True
        model = MCTCTForCTC(config=config)
        model.to(torch_device)
        model.train()

        input_features = input_features[:3]

        input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]
        max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
        labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 1), model.config.vocab_size)

        # pad input
        for i in range(len(input_lengths)):
            input_features[i, input_lengths[i] :] = 0.0

            if max_length_labels[i] < labels.shape[-1]:
                # it's important that we make sure that target lengths are at least
                # one shorter than logit lengths to prevent -inf
                labels[i, max_length_labels[i] - 1 :] = -100

        loss = model(input_features, labels=labels).loss
        self.parent.assertFalse(torch.isinf(loss).item())

        loss.backward()

    def check_labels_out_of_vocab(self, config, input_features, *args):
        """Check that labels drawn beyond the vocabulary make the CTC model raise ValueError."""
        model = MCTCTForCTC(config)
        model.to(torch_device)
        model.train()

        input_features = input_features[:3]

        # -2 since input_features is a 2D window for each sequence in batch
        input_lengths = [input_features.shape[-2] // i for i in [4, 2, 1]]
        max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
        labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)

        with self.parent.assertRaises(ValueError):
            model(input_features, labels=labels)

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the form expected by `ModelTesterMixin`."""
        config, input_features, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {"input_features": input_features, "attention_mask": attention_mask}
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class MCTCTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model and pipeline tests for `MCTCTModel` and `MCTCTForCTC` on a small random config."""

    all_model_classes = (MCTCTForCTC, MCTCTModel) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"automatic-speech-recognition": MCTCTForCTC, "feature-extraction": MCTCTModel} if is_torch_available() else {}
    )
    test_pruning = False
    test_headmasking = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = MCTCTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MCTCTConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_ctc_loss_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_loss(*config_and_inputs)

    def test_ctc_train(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_training(*config_and_inputs)

    def test_labels_out_of_vocab(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_labels_out_of_vocab(*config_and_inputs)

    @unittest.skip(reason="MCTCT has no inputs_embeds")
    def test_inputs_embeds(self):
        pass

    # `input_ids` is renamed to `input_features`
    def test_forward_signature(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_features",
                "attention_mask",
                "head_mask",
                "output_attentions",
                "output_hidden_states",
                "return_dict",
            ]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    @unittest.skip(reason="MCTCT has no token embeddings, so they cannot be resized")
    def test_resize_tokens_embeddings(self):
        pass

    @unittest.skip(reason="MCTCT has no inputs_embeds")
    def test_model_common_attributes(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True
        config.layerdrop = 0.0

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        input_features = inputs_dict["input_features"]

        input_lengths = torch.tensor(
            [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
        )
        output_lengths = model._get_feat_extract_output_lengths(input_lengths)

        labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
        inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
        inputs_dict["labels"] = labels

        outputs = model(**inputs_dict)

        output = outputs[0]

        # Encoder-/Decoder-only models
        hidden_states = outputs.hidden_states[0]
        attentions = outputs.attentions[0]

        hidden_states.retain_grad()
        attentions.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)
        self.assertIsNotNone(attentions.grad)

    def test_initialization(self):
        """Check every trainable parameter is initialized to mean 0 or 1 under a zero-init config.

        Parameters whose name contains one of the `uniform_init_parms` substrings only need a mean in [-1, 1].
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # Built once rather than once per parameter.
        uniform_init_parms = [
            "conv.weight",
            "masked_spec_embed",
            "codevectors",
            "quantizer.weight_proj.weight",
            "project_hid.weight",
            "project_hid.bias",
            "project_q.weight",
            "project_q.bias",
            "feature_projection.projection.weight",
            "feature_projection.projection.bias",
            "objective.weight",
        ]

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                # Round to 9 decimals so tiny floating-point noise does not break the equality checks.
                mean = ((param.data.mean() * 1e9).round() / 1e9).item()
                if any(x in name for x in uniform_init_parms):
                    self.assertTrue(
                        -1.0 <= mean <= 1.0,
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
                else:
                    self.assertIn(
                        mean,
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "codevectors") and module.codevectors is not None:
            module.codevectors.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)

    @slow
    def test_model_from_pretrained(self):
        model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")
        self.assertIsNotNone(model)
|
|
|
|
|
|
@require_torch
class MCTCTRobustModelTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for MCTCT, additionally covering padded batched inference."""

    all_model_classes = (MCTCTForCTC, MCTCTModel) if is_torch_available() else ()
    test_pruning = False
    test_headmasking = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = MCTCTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MCTCTConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_batched_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_batch_inference(*config_and_inputs)

    def test_ctc_loss_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_loss(*config_and_inputs)

    def test_ctc_train(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_training(*config_and_inputs)

    def test_labels_out_of_vocab(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_labels_out_of_vocab(*config_and_inputs)

    @unittest.skip(reason="MCTCT has no inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="`input_ids` is renamed to `input_features`; the signature is checked in MCTCTModelTest")
    def test_forward_signature(self):
        pass

    @unittest.skip(reason="MCTCT has no token embeddings, so they cannot be resized")
    def test_resize_tokens_embeddings(self):
        pass

    @unittest.skip(reason="MCTCT has no inputs_embeds, so `get_input_embeddings` is not implemented")
    def test_model_common_attributes(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        # set layer drop to 0
        model.config.layerdrop = 0.0

        input_features = inputs_dict["input_features"]

        input_lengths = torch.tensor(
            [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
        )
        output_lengths = model._get_feat_extract_output_lengths(input_lengths)

        labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
        inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
        inputs_dict["labels"] = labels

        outputs = model(**inputs_dict)

        output = outputs[0]

        # Encoder-/Decoder-only models
        hidden_states = outputs.hidden_states[0]
        attentions = outputs.attentions[0]

        hidden_states.retain_grad()
        attentions.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)
        self.assertIsNotNone(attentions.grad)

    def test_initialization(self):
        """Check every trainable parameter is initialized to mean 0 or 1 under a zero-init config.

        Parameters whose name contains one of the `uniform_init_parms` substrings only need a mean in [-1, 1].
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # Built once rather than once per parameter.
        uniform_init_parms = [
            "conv.weight",
            "masked_spec_embed",
            "codevectors",
            "quantizer.weight_proj.weight",
            "project_hid.weight",
            "project_hid.bias",
            "project_q.weight",
            "project_q.bias",
            "feature_projection.projection.weight",
            "feature_projection.projection.bias",
            "objective.weight",
        ]

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                # Round to 9 decimals so tiny floating-point noise does not break the equality checks.
                mean = ((param.data.mean() * 1e9).round() / 1e9).item()
                if any(x in name for x in uniform_init_parms):
                    self.assertTrue(
                        -1.0 <= mean <= 1.0,
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
                else:
                    self.assertIn(
                        mean,
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "codevectors") and module.codevectors is not None:
            module.codevectors.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)

    @unittest.skip(reason="Feed forward chunking is not implemented")
    def test_feed_forward_chunking(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")
        self.assertIsNotNone(model)
|
|
|
|
|
|
@require_torch
@require_soundfile
@slow
class MCTCTModelIntegrationTest(unittest.TestCase):
    """End-to-end transcription checks for the pretrained `speechbrain/m-ctc-t-large` checkpoint."""

    def _load_datasamples(self, num_samples):
        """Return the raw audio arrays of utterances `1272-141231-0000` .. `1272-141231-{num_samples - 1:04d}`."""
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        # Build the wanted ids once instead of once per row; zero-pad so indices >= 10 still match.
        wanted_ids = [f"1272-141231-{i:04d}" for i in range(num_samples)]
        speech_samples = ds.sort("id").filter(lambda x: x["id"] in wanted_ids)[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_inference_ctc_normal(self):
        model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
        model.to(torch_device)
        processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
        input_speech = self._load_datasamples(1)

        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

        with torch.no_grad():
            logits = model(input_features).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = ["a man said to the universe, sir, i exist."]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_ctc_normal_batched(self):
        model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
        model.to(torch_device)
        processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)

        input_speech = self._load_datasamples(2)

        inputs = processor(input_speech, return_tensors="pt", padding=True)

        input_features = inputs.input_features.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            logits = model(input_features, attention_mask=attention_mask).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe, sir, i exist.",
            '"sweat-covered brion\'s body, trickling into the tight-lowing clossa was the only germent huor."',
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_ctc_robust_batched(self):
        model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large").to(torch_device)
        processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)

        input_speech = self._load_datasamples(4)

        inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)

        input_features = inputs.input_features.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            logits = model(input_features, attention_mask=attention_mask).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe, sir, i exist.",
            '"sweat-covered brion\'s body, trickling into the tight-lowing clossa was the only germent huor." "',
            "\"the cadona's chest still-dripping bloodthe acofis overstrained eyes, even the soring arena around him"
            " with thousands of spectators retrivialities not worth-thinking about.",
            "his instant panic was followed by a small sharp blow high on his chestr.",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|