Add SEW CTC models (#14158)
* Add SEW CTC models
* Update paths
* Update paths

This commit is contained in:
parent 1e53faeb2e
commit e1dc5afd28
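The diff below touches the SEW and SEW-D configurations, the fairseq conversion scripts, the modeling files, and the integration tests. As a quick orientation, here is a hedged sketch of what the finished feature looks like from the user side: greedy CTC transcription with the new SEWForCTC head. It is adapted from the docstring example and the updated integration test in this diff; the checkpoint and dataset identifiers are taken from those snippets, so treat them as illustrative rather than guaranteed.

# Hedged sketch of SEW CTC inference, adapted from the docstring example and the
# integration test in this diff; checkpoint and dataset ids come from those
# snippets and may differ on the Hub.
import soundfile as sf
import torch
from datasets import load_dataset
from transformers import SEWForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h", do_lower_case=True)
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")


def map_to_array(batch):
    # read the raw 16 kHz waveform from disk, as in the docstring example
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch


ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)

inputs = processor(ds["speech"][:2], return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values).logits

# greedy CTC decoding: argmax over the vocabulary, then collapse repeats and blanks
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids))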
@@ -164,7 +164,7 @@ class SEWConfig(PretrainedConfig):
         mask_time_length=10,
         mask_feature_prob=0.0,
         mask_feature_length=10,
-        ctc_loss_reduction="sum",
+        ctc_loss_reduction="mean",
         ctc_zero_infinity=False,
         use_weighted_layer_sum=False,
         classifier_proj_size=256,
@@ -52,7 +52,7 @@ MAPPING = {
     "final_layer_norm": "encoder.layers.*.final_layer_norm",
     "encoder.upsample.0": "encoder.upsample.projection",
     "encoder.layer_norm": "encoder.layer_norm",
-    "w2v_encoder.layer_norm": "layer_norm",
+    "w2v_model.layer_norm": "layer_norm",
     "w2v_encoder.proj": "lm_head",
     "mask_emb": "masked_spec_embed",
 }
@@ -106,7 +106,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
         for key, mapped_key in MAPPING.items():
             mapped_key = "sew." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key

-            if key in name or key.split("w2v_encoder.")[-1] == name.split(".")[0]:
+            if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
                 is_used = True
                 if "*" in mapped_key:
                     layer_index = name.split(key)[0].split(".")[-2]
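To make the wildcard renaming above concrete, the following standalone sketch shows how a MAPPING entry such as "encoder.layers.*.final_layer_norm" resolves a fairseq parameter name. The sample weight name and the resolve helper are illustrative; only the string handling mirrors the loop above (the real script additionally dispatches the .weight/.bias suffixes separately).

# Illustrative sketch of the wildcard substitution in recursively_load_weights.
# The weight name and the helper below are hypothetical; only the string handling
# mirrors the loop above.
MAPPING = {"final_layer_norm": "encoder.layers.*.final_layer_norm"}


def resolve(name, is_finetuned=True):
    for key, mapped_key in MAPPING.items():
        # fine-tuned checkpoints live under the "sew." prefix, except for lm_head
        mapped_key = "sew." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
        if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
            if "*" in mapped_key:
                # the layer index is the dot-separated field right before the matched key
                layer_index = name.split(key)[0].split(".")[-2]
                mapped_key = mapped_key.replace("*", layer_index)
            return mapped_key
    return None


print(resolve("w2v_encoder.w2v_model.encoder.layers.3.final_layer_norm.weight"))
# -> sew.encoder.layers.3.final_layer_norm (the script then copies .weight/.bias separately)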
@@ -165,13 +165,13 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
         unused_weights.append(full_name)


-def convert_config(model):
+def convert_config(model, is_finetuned):
     config = SEWConfig()
-    fs_config = model.cfg
+    if is_finetuned:
+        fs_config = model.w2v_encoder.w2v_model.cfg
+    else:
+        fs_config = model.cfg

-    config.activation_dropout = fs_config.activation_dropout
-    config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
-    config.attention_dropout = fs_config.attention_dropout
     config.conv_bias = fs_config.conv_bias
     conv_layers = eval(fs_config.conv_feature_layers)
     config.conv_dim = [x[0] for x in conv_layers]
@@ -179,19 +179,13 @@ def convert_config(model):
     config.conv_stride = [x[2] for x in conv_layers]
     config.feat_extract_activation = "gelu"
     config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
-    config.feat_proj_dropout = fs_config.dropout_input
-    config.final_dropout = 0.0
     config.hidden_act = fs_config.activation_fn.name
     config.hidden_dropout = fs_config.dropout
     config.hidden_size = fs_config.encoder_embed_dim
     config.initializer_range = 0.02
     config.intermediate_size = fs_config.encoder_ffn_embed_dim
     config.layer_norm_eps = 1e-5
     config.layerdrop = fs_config.encoder_layerdrop
-    config.mask_feature_length = fs_config.mask_channel_length
-    config.mask_feature_prob = fs_config.mask_channel_prob
-    config.mask_time_length = fs_config.mask_length
-    config.mask_time_prob = fs_config.mask_prob
     config.num_attention_heads = fs_config.encoder_attention_heads
     config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
     config.num_conv_pos_embeddings = fs_config.conv_pos
@@ -199,6 +193,24 @@ def convert_config(model):
     config.num_hidden_layers = fs_config.encoder_layers
     config.squeeze_factor = fs_config.squeeze_factor

+    # take care of any params that are overridden by the Wav2VecCtc model
+    if is_finetuned:
+        fs_config = model.cfg
+        config.final_dropout = fs_config.final_dropout
+        config.layerdrop = fs_config.layerdrop
+        config.activation_dropout = fs_config.activation_dropout
+        config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
+        config.attention_dropout = fs_config.attention_dropout
+        config.feat_proj_dropout = fs_config.dropout_input
+        config.hidden_dropout = fs_config.dropout
+        config.mask_feature_length = fs_config.mask_channel_length
+        config.mask_feature_prob = fs_config.mask_channel_prob
+        config.mask_time_length = fs_config.mask_length
+        config.mask_time_prob = fs_config.mask_prob
+
     config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
     config.tokenizer_class = "Wav2Vec2CTCTokenizer"

     return config
@@ -220,7 +232,7 @@ def convert_sew_checkpoint(
     if config_path is not None:
         config = SEWConfig.from_pretrained(config_path)
     else:
-        config = convert_config(model[0])
+        config = convert_config(model[0], is_finetuned)
     model = model[0].eval()

     return_attention_mask = True if config.feat_extract_norm == "layer" else False
@@ -238,6 +250,8 @@ def convert_sew_checkpoint(

         # important change bos & pad token id since CTC symbol is <pad> and
         # not <s> as in fairseq
+        target_dict.indices[target_dict.bos_word] = target_dict.pad_index
+        target_dict.indices[target_dict.pad_word] = target_dict.bos_index
         config.bos_token_id = target_dict.pad_index
         config.pad_token_id = target_dict.bos_index
         config.eos_token_id = target_dict.eos_index
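The bos/pad swap just above is easy to misread, so here is a small hedged illustration of its effect, assuming the usual fairseq dictionary layout (<s>=0, <pad>=1, </s>=2): after the swap, index 0 holds the <pad> token that Wav2Vec2CTCTokenizer treats as the CTC blank, and the former <s> slot becomes bos. The plain dict below stands in for the fairseq Dictionary object and is illustrative only.

# Hedged illustration of the bos/pad swap above; assumes the standard fairseq
# dictionary layout with <s>=0, <pad>=1, </s>=2. A plain dict stands in for the
# fairseq Dictionary object.
indices = {"<s>": 0, "<pad>": 1, "</s>": 2}
bos_index, pad_index, eos_index = indices["<s>"], indices["<pad>"], indices["</s>"]

# swap the dictionary entries, as in the conversion script
indices["<s>"] = pad_index
indices["<pad>"] = bos_index

config = {
    "bos_token_id": pad_index,  # 1
    "pad_token_id": bos_index,  # 0 -> index 0 is now <pad>, the CTC blank
    "eos_token_id": eos_index,  # 2
}
print(indices)  # {'<s>': 1, '<pad>': 0, '</s>': 2}
print(config)   # {'bos_token_id': 1, 'pad_token_id': 0, 'eos_token_id': 2}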
@@ -26,12 +26,7 @@ from torch.nn import CrossEntropyLoss
 from transformers.deepspeed import is_deepspeed_zero3_enabled

 from ...activations import ACT2FN
-from ...file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
@@ -788,6 +783,11 @@ class SEWModel(SEWPreTrainedModel):
         self.feature_extractor = SEWFeatureExtractor(config)
         self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)

+        self.project_features = config.conv_dim[-1] != config.hidden_size
+        if self.project_features:
+            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
+
         self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

         self.encoder = SEWEncoder(config)
@@ -841,7 +841,13 @@ class SEWModel(SEWPreTrainedModel):
         return hidden_states

     @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_PROCESSOR_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+    )
     def forward(
         self,
         input_values,
@@ -851,30 +857,6 @@ class SEWModel(SEWPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
-        """
-
-        Returns:
-
-        Example::
-
-            >>> from transformers import Wav2Vec2Processor, SEWModel
-            >>> from datasets import load_dataset
-            >>> import soundfile as sf
-
-            >>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
-            >>> model = SEWModel.from_pretrained("asapp/sew-tiny-100k")
-
-            >>> def map_to_array(batch):
-            ...     speech, _ = sf.read(batch["file"])
-            ...     batch["speech"] = speech
-            ...     return batch
-
-            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
-            >>> ds = ds.map(map_to_array)
-
-            >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
-            >>> hidden_states = model(input_values).last_hidden_state
-        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -885,11 +867,15 @@ class SEWModel(SEWPreTrainedModel):
         extract_features = extract_features.transpose(1, 2)
         extract_features = self.layer_norm(extract_features)

+        if self.project_features:
+            extract_features = self.feature_projection(extract_features)
+        hidden_states = self.feature_dropout(extract_features)
+
         if attention_mask is not None:
             # compute reduced attention_mask corresponding to feature vectors
-            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+            attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)

-        hidden_states = self._mask_hidden_states(extract_features, mask_time_indices=mask_time_indices)
+        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

         encoder_outputs = self.encoder(
             hidden_states,
@@ -189,7 +189,7 @@ class SEWDConfig(PretrainedConfig):
         mask_time_length=10,
         mask_feature_prob=0.0,
         mask_feature_length=10,
-        ctc_loss_reduction="sum",
+        ctc_loss_reduction="mean",
         ctc_zero_infinity=False,
         use_weighted_layer_sum=False,
         classifier_proj_size=256,
@@ -54,7 +54,7 @@ MAPPING = {
     "encoder.encoder.LayerNorm": "encoder.encoder.LayerNorm",
     "encoder.upsample.0": "encoder.upsample.projection",
     "encoder.layer_norm": "encoder.layer_norm",
-    "w2v_encoder.layer_norm": "layer_norm",
+    "w2v_model.layer_norm": "layer_norm",
     "w2v_encoder.proj": "lm_head",
     "mask_emb": "masked_spec_embed",
 }
@@ -91,7 +91,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
     unused_weights = []
     fairseq_dict = fairseq_model.state_dict()

-    feature_extractor = hf_model.sew.feature_extractor if is_finetuned else hf_model.feature_extractor
+    feature_extractor = hf_model.sew_d.feature_extractor if is_finetuned else hf_model.feature_extractor

     for name, value in fairseq_dict.items():
         is_used = False
@@ -108,7 +108,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
         for key, mapped_key in MAPPING.items():
             mapped_key = "sew_d." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key

-            if key in name or key.split("w2v_encoder.")[-1] == name.split(".")[0]:
+            if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
                 is_used = True
                 if "*" in mapped_key:
                     layer_index = name.split(key)[0].split(".")[-2]
@@ -169,13 +169,13 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
         unused_weights.append(full_name)


-def convert_config(model):
+def convert_config(model, is_finetuned):
     config = SEWDConfig()
-    fs_config = model.cfg
+    if is_finetuned:
+        fs_config = model.w2v_encoder.w2v_model.cfg
+    else:
+        fs_config = model.cfg

-    config.activation_dropout = fs_config.activation_dropout
-    config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
-    config.attention_dropout = fs_config.attention_dropout
     config.conv_bias = fs_config.conv_bias
     conv_layers = eval(fs_config.conv_feature_layers)
     config.conv_dim = [x[0] for x in conv_layers]
@@ -183,19 +183,13 @@ def convert_config(model):
     config.conv_stride = [x[2] for x in conv_layers]
     config.feat_extract_activation = "gelu"
     config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
-    config.feat_proj_dropout = fs_config.dropout_input
-    config.final_dropout = 0.0
     config.hidden_act = fs_config.activation_fn.name
     config.hidden_dropout = fs_config.dropout
     config.hidden_size = fs_config.encoder_embed_dim
     config.initializer_range = 0.02
     config.intermediate_size = fs_config.encoder_ffn_embed_dim
     config.layer_norm_eps = 1e-5
     config.layerdrop = fs_config.encoder_layerdrop
-    config.mask_feature_length = fs_config.mask_channel_length
-    config.mask_feature_prob = fs_config.mask_channel_prob
-    config.mask_time_length = fs_config.mask_length
-    config.mask_time_prob = fs_config.mask_prob
     config.num_attention_heads = fs_config.encoder_attention_heads
     config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
     config.num_conv_pos_embeddings = fs_config.conv_pos
@@ -211,6 +205,24 @@ def convert_config(model):
     config.pos_att_type = tuple(fs_config.pos_att_type.split("|"))
     config.norm_rel_ebd = fs_config.norm_rel_ebd

+    # take care of any params that are overridden by the Wav2VecCtc model
+    if is_finetuned:
+        fs_config = model.cfg
+        config.final_dropout = fs_config.final_dropout
+        config.layerdrop = fs_config.layerdrop
+        config.activation_dropout = fs_config.activation_dropout
+        config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
+        config.attention_dropout = fs_config.attention_dropout
+        config.feat_proj_dropout = fs_config.dropout_input
+        config.hidden_dropout = fs_config.dropout
+        config.mask_feature_length = fs_config.mask_channel_length
+        config.mask_feature_prob = fs_config.mask_channel_prob
+        config.mask_time_length = fs_config.mask_length
+        config.mask_time_prob = fs_config.mask_prob
+
     config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
     config.tokenizer_class = "Wav2Vec2CTCTokenizer"

     return config
@@ -232,7 +244,7 @@ def convert_sew_checkpoint(
     if config_path is not None:
         config = SEWDConfig.from_pretrained(config_path)
     else:
-        config = convert_config(model[0])
+        config = convert_config(model[0], is_finetuned)
     model = model[0].eval()

     return_attention_mask = True if config.feat_extract_norm == "layer" else False
@@ -250,6 +262,8 @@ def convert_sew_checkpoint(

         # important change bos & pad token id since CTC symbol is <pad> and
         # not <s> as in fairseq
+        target_dict.indices[target_dict.bos_word] = target_dict.pad_index
+        target_dict.indices[target_dict.pad_word] = target_dict.bos_index
         config.bos_token_id = target_dict.pad_index
         config.pad_token_id = target_dict.bos_index
         config.eos_token_id = target_dict.eos_index
@@ -27,12 +27,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm
 from transformers.deepspeed import is_deepspeed_zero3_enabled

 from ...activations import ACT2FN
-from ...file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
@@ -1291,13 +1286,17 @@ SEWD_INPUTS_DOCSTRING = r"""
     "The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
     SEWD_START_DOCSTRING,
 )
 # Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD
 class SEWDModel(SEWDPreTrainedModel):
     def __init__(self, config: SEWDConfig):
         super().__init__(config)
         self.config = config
         self.feature_extractor = SEWDFeatureExtractor(config)
         self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
-        self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)

+        self.project_features = config.conv_dim[-1] != config.hidden_size
+        if self.project_features:
+            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
+
         self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
@@ -1353,7 +1352,13 @@ class SEWDModel(SEWDPreTrainedModel):
         return hidden_states

     @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_PROCESSOR_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+    )
     def forward(
         self,
         input_values,
@@ -1363,30 +1368,6 @@ class SEWDModel(SEWDPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
-        """
-
-        Returns:
-
-        Example::
-
-            >>> from transformers import Wav2Vec2Processor, SEWDModel
-            >>> from datasets import load_dataset
-            >>> import soundfile as sf
-
-            >>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
-            >>> model = SEWDModel.from_pretrained("asapp/sew-tiny-100k")
-
-            >>> def map_to_array(batch):
-            ...     speech, _ = sf.read(batch["file"])
-            ...     batch["speech"] = speech
-            ...     return batch
-
-            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
-            >>> ds = ds.map(map_to_array)
-
-            >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
-            >>> hidden_states = model(input_values).last_hidden_state
-        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1397,12 +1378,13 @@ class SEWDModel(SEWDPreTrainedModel):
         extract_features = extract_features.transpose(1, 2)
         extract_features = self.layer_norm(extract_features)

+        if self.project_features:
+            extract_features = self.feature_projection(extract_features)
+        hidden_states = self.feature_dropout(extract_features)
+
         if attention_mask is not None:
             # compute reduced attention_mask corresponding to feature vectors
-            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
-
-        hidden_states = self.feature_projection(extract_features)
-        hidden_states = self.feature_dropout(hidden_states)
+            attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)

         hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
@@ -22,7 +22,7 @@ import pytest

 from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
 from transformers import SEWConfig, is_torch_available
-from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, tooslow, torch_device
+from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, _config_zero_init
@@ -531,27 +531,24 @@ class SEWModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
         self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)

-    @tooslow
     def test_inference_ctc_batched(self):
-        # TODO: enable this test once the finetuned models are available
-        model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-100h").to(torch_device)
-        processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-100h", do_lower_case=True)
+        model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h").to(torch_device)
+        processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h", do_lower_case=True)

         input_speech = self._load_datasamples(2)

         inputs = processor(input_speech, return_tensors="pt", padding=True)

         input_values = inputs.input_values.to(torch_device)
-        attention_mask = inputs.attention_mask.to(torch_device)

         with torch.no_grad():
-            logits = model(input_values, attention_mask=attention_mask).logits
+            logits = model(input_values).logits

         predicted_ids = torch.argmax(logits, dim=-1)
         predicted_trans = processor.batch_decode(predicted_ids)

         EXPECTED_TRANSCRIPTIONS = [
             "a man said to the universe sir i exist",
-            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
+            "swet covered brian's body trickling into the tightloine closs hat was the only garment he wore",
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@@ -22,7 +22,7 @@ import pytest

 from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
 from transformers import SEWDConfig, is_torch_available
-from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, tooslow, torch_device
+from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, _config_zero_init
@@ -544,27 +544,24 @@ class SEWDModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
         self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)

-    @tooslow
     def test_inference_ctc_batched(self):
-        # TODO: enable this test once the finetuned models are available
-        model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-100h").to(torch_device)
-        processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k-ft-100h", do_lower_case=True)
+        model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h").to(torch_device)
+        processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h", do_lower_case=True)

         input_speech = self._load_datasamples(2)

         inputs = processor(input_speech, return_tensors="pt", padding=True)

         input_values = inputs.input_values.to(torch_device)
-        attention_mask = inputs.attention_mask.to(torch_device)

         with torch.no_grad():
-            logits = model(input_values, attention_mask=attention_mask).logits
+            logits = model(input_values).logits

         predicted_ids = torch.argmax(logits, dim=-1)
         predicted_trans = processor.batch_decode(predicted_ids)

         EXPECTED_TRANSCRIPTIONS = [
             "a man said to the universe sir i exist",
-            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
+            "swet covered breon's body trickling into the titlowing closs that was the only garmened he war",
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)