Add SEW CTC models (#14158)

* Add SEW CTC models

* Update paths

* Update paths
Anton Lozhkov 2021-10-27 12:21:09 +03:00 committed by GitHub
parent 1e53faeb2e
commit e1dc5afd28
8 changed files with 106 additions and 116 deletions

View File

@ -164,7 +164,7 @@ class SEWConfig(PretrainedConfig):
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_loss_reduction="mean",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
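The default `ctc_loss_reduction` flips from `"sum"` to `"mean"`, which normalizes the loss by target length and batch size so its scale no longer grows with either. A minimal sketch of how a Wav2Vec2-style CTC head consumes this setting (tensor shapes and the helper name are illustrative, not the exact `SEWForCTC` code):

```python
import torch.nn.functional as F

def ctc_loss_from_config(log_probs, labels, input_lengths, reduction="mean"):
    # log_probs: (time, batch, vocab) after log_softmax; labels padded with -100
    labels_mask = labels >= 0
    target_lengths = labels_mask.sum(-1)
    flattened_targets = labels.masked_select(labels_mask)
    # reduction="mean" divides each item by its target length and averages over
    # the batch; "sum" adds raw per-item losses, so the scale grows with batch size
    return F.ctc_loss(
        log_probs, flattened_targets, input_lengths, target_lengths,
        blank=0, reduction=reduction, zero_infinity=False,  # cf. ctc_zero_infinity
    )
```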

View File

@ -52,7 +52,7 @@ MAPPING = {
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.upsample.0": "encoder.upsample.projection",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_encoder.layer_norm": "layer_norm",
"w2v_model.layer_norm": "layer_norm",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
@ -106,7 +106,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
for key, mapped_key in MAPPING.items():
mapped_key = "sew." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
-if key in name or key.split("w2v_encoder.")[-1] == name.split(".")[0]:
+if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
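The conversion walks every fairseq weight and tests it against each `MAPPING` key. Fine-tuned checkpoints keep their parameters under a `w2v_encoder.w2v_model.` prefix, so the fallback match must strip `w2v_model.`, not `w2v_encoder.`, from the mapping key. Plain `str.split` mechanics show why the old predicate could never fire (the weight name below is a hypothetical example):

```python
key = "w2v_model.layer_norm"          # mapping key after the rename above
name = "layer_norm.bias"              # hypothetical fine-tuned weight name

old = key.split("w2v_encoder.")[-1]   # "w2v_model.layer_norm" -- separator absent, no-op
new = key.split("w2v_model.")[-1]     # "layer_norm"           -- prefix stripped

assert new == name.split(".")[0]      # matches with the fix
assert old != name.split(".")[0]      # never matched before
```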
@ -165,13 +165,13 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
unused_weights.append(full_name)
-def convert_config(model):
+def convert_config(model, is_finetuned):
config = SEWConfig()
-fs_config = model.cfg
+if is_finetuned:
+    fs_config = model.w2v_encoder.w2v_model.cfg
+else:
+    fs_config = model.cfg
-config.activation_dropout = fs_config.activation_dropout
-config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
-config.attention_dropout = fs_config.attention_dropout
config.conv_bias = fs_config.conv_bias
conv_layers = eval(fs_config.conv_feature_layers)
config.conv_dim = [x[0] for x in conv_layers]
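fairseq stores the feature-extractor topology as a Python-literal string that `eval` expands into `(dim, kernel, stride)` tuples, which the surrounding lines split column-wise. A sketch with a made-up value; `ast.literal_eval` is the safer drop-in when experimenting:

```python
import ast

# Hypothetical value; the real one comes from fs_config.conv_feature_layers.
conv_feature_layers = "[(64, 10, 5), (128, 3, 2), (128, 3, 2), (128, 2, 2)]"

conv_layers = ast.literal_eval(conv_feature_layers)
conv_dim = [x[0] for x in conv_layers]     # [64, 128, 128, 128]
conv_kernel = [x[1] for x in conv_layers]  # [10, 3, 3, 2]
conv_stride = [x[2] for x in conv_layers]  # [5, 2, 2, 2]
```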
@ -179,19 +179,13 @@ def convert_config(model):
config.conv_stride = [x[2] for x in conv_layers]
config.feat_extract_activation = "gelu"
config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
-config.feat_proj_dropout = fs_config.dropout_input
config.final_dropout = 0.0
config.hidden_act = fs_config.activation_fn.name
-config.hidden_dropout = fs_config.dropout
config.hidden_size = fs_config.encoder_embed_dim
config.initializer_range = 0.02
config.intermediate_size = fs_config.encoder_ffn_embed_dim
config.layer_norm_eps = 1e-5
config.layerdrop = fs_config.encoder_layerdrop
-config.mask_feature_length = fs_config.mask_channel_length
-config.mask_feature_prob = fs_config.mask_channel_prob
-config.mask_time_length = fs_config.mask_length
-config.mask_time_prob = fs_config.mask_prob
config.num_attention_heads = fs_config.encoder_attention_heads
config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
config.num_conv_pos_embeddings = fs_config.conv_pos
@ -199,6 +193,24 @@ def convert_config(model):
config.num_hidden_layers = fs_config.encoder_layers
config.squeeze_factor = fs_config.squeeze_factor
+# take care of any params that are overridden by the Wav2VecCtc model
+if is_finetuned:
+    fs_config = model.cfg
+    config.final_dropout = fs_config.final_dropout
+    config.layerdrop = fs_config.layerdrop
+config.activation_dropout = fs_config.activation_dropout
+config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
+config.attention_dropout = fs_config.attention_dropout
+config.feat_proj_dropout = fs_config.dropout_input
+config.hidden_dropout = fs_config.dropout
+config.mask_feature_length = fs_config.mask_channel_length
+config.mask_feature_prob = fs_config.mask_channel_prob
+config.mask_time_length = fs_config.mask_length
+config.mask_time_prob = fs_config.mask_prob
+config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
+config.tokenizer_class = "Wav2Vec2CTCTokenizer"
return config
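For fine-tuned checkpoints the function now reads two config levels: the inner pretraining config supplies the architecture, and the outer fairseq `Wav2VecCtc` config then overrides regularization and SpecAugment settings. A stand-in sketch of that override order (all attribute values invented):

```python
from types import SimpleNamespace

pretrain_cfg = SimpleNamespace(encoder_embed_dim=768, dropout=0.1, mask_prob=0.65)
finetune_cfg = SimpleNamespace(final_dropout=0.0, layerdrop=0.1, dropout=0.05, mask_prob=0.5)

hidden_size = pretrain_cfg.encoder_embed_dim  # architectural: read from the inner config
final_dropout = finetune_cfg.final_dropout    # only exists on the CTC wrapper
hidden_dropout = finetune_cfg.dropout         # overridden when is_finetuned is True
mask_time_prob = finetune_cfg.mask_prob       # ditto for SpecAugment
```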
@ -220,7 +232,7 @@ def convert_sew_checkpoint(
if config_path is not None:
config = SEWConfig.from_pretrained(config_path)
else:
-config = convert_config(model[0])
+config = convert_config(model[0], is_finetuned)
model = model[0].eval()
return_attention_mask = True if config.feat_extract_norm == "layer" else False
@ -238,6 +250,8 @@ def convert_sew_checkpoint(
# important change bos & pad token id since CTC symbol is <pad> and
# not <s> as in fairseq
+target_dict.indices[target_dict.bos_word] = target_dict.pad_index
+target_dict.indices[target_dict.pad_word] = target_dict.bos_index
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
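fairseq dictionaries place `<s>` at index 0 and `<pad>` at index 1, while Wav2Vec2-style CTC decoding expects the blank token `<pad>` at index 0; the two new lines swap the dictionary entries, and the config ids mirror the swap so the logit columns stay aligned with the vocabulary. A toy illustration:

```python
# fairseq-style defaults: bos=0, pad=1, eos=2, unk=3
indices = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
bos_index, pad_index = indices["<s>"], indices["<pad>"]

indices["<s>"] = pad_index    # <s> moves to 1
indices["<pad>"] = bos_index  # <pad>, the CTC blank, moves to 0

assert indices["<pad>"] == 0 and indices["<s>"] == 1
```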

View File

@ -26,12 +26,7 @@ from torch.nn import CrossEntropyLoss
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
-from ...file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@ -788,6 +783,11 @@ class SEWModel(SEWPreTrainedModel):
self.feature_extractor = SEWFeatureExtractor(config)
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+self.project_features = config.conv_dim[-1] != config.hidden_size
+if self.project_features:
+    self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
self.encoder = SEWEncoder(config)
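`SEWModel` now instantiates the feature projection only when the final conv dimension differs from the transformer width, so checkpoints where the two already agree carry no dead `nn.Linear`. A self-contained sketch of the conditional path (dimensions invented):

```python
import torch
import torch.nn as nn

conv_dim_last, hidden_size = 512, 768
project_features = conv_dim_last != hidden_size

feature_projection = nn.Linear(conv_dim_last, hidden_size) if project_features else None
feature_dropout = nn.Dropout(0.1)

x = torch.randn(2, 99, conv_dim_last)  # (batch, frames, conv_dim[-1])
if project_features:
    x = feature_projection(x)          # -> (2, 99, hidden_size)
x = feature_dropout(x)
```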
@ -841,7 +841,13 @@ class SEWModel(SEWPreTrainedModel):
return hidden_states
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+@add_code_sample_docstrings(
+    processor_class=_PROCESSOR_FOR_DOC,
+    checkpoint=_CHECKPOINT_FOR_DOC,
+    output_type=BaseModelOutput,
+    config_class=_CONFIG_FOR_DOC,
+    modality="audio",
+)
def forward(
self,
input_values,
@ -851,30 +857,6 @@ class SEWModel(SEWPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Example::
>>> from transformers import Wav2Vec2Processor, SEWModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
>>> model = SEWModel.from_pretrained("asapp/sew-tiny-100k")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
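The deleted `soundfile`-based example above is superseded by the standardized code sample that `add_code_sample_docstrings` generates. Roughly equivalent usage would look like this (a sketch; the decoded `audio` column is an assumption about the dataset, not text from the generated docstring):

```python
from datasets import load_dataset
from transformers import Wav2Vec2Processor, SEWModel

ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
model = SEWModel.from_pretrained("asapp/sew-tiny-100k")

inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
last_hidden_state = model(**inputs).last_hidden_state
```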
@ -885,11 +867,15 @@ class SEWModel(SEWPreTrainedModel):
extract_features = extract_features.transpose(1, 2)
extract_features = self.layer_norm(extract_features)
+if self.project_features:
+    extract_features = self.feature_projection(extract_features)
+hidden_states = self.feature_dropout(extract_features)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
-attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-hidden_states = self._mask_hidden_states(extract_features, mask_time_indices=mask_time_indices)
+hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
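Two fixes land in this forward pass: features are (optionally) projected and dropout-regularized before masking, and the reduced attention mask is now derived from `hidden_states.shape[1]`. Projection never changes the time axis, so the mask length itself is unchanged; it comes from folding the raw-audio mask through the conv stack's output-length formula, sketched below (kernel/stride values are illustrative):

```python
import torch

# Fold the raw-audio mask through the conv feature extractor's output-length formula.
def feat_extract_output_lengths(input_lengths, kernels=(10, 3, 3, 2), strides=(5, 2, 2, 2)):
    for k, s in zip(kernels, strides):
        input_lengths = (input_lengths - k) // s + 1  # 1D conv output length
    return input_lengths

raw_mask = torch.ones(1, 16_000, dtype=torch.long)    # 1 s of valid audio samples
print(feat_extract_output_lengths(raw_mask.sum(-1)))  # number of valid feature frames
```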

View File

@ -189,7 +189,7 @@ class SEWDConfig(PretrainedConfig):
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_loss_reduction="mean",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,

View File

@ -54,7 +54,7 @@ MAPPING = {
"encoder.encoder.LayerNorm": "encoder.encoder.LayerNorm",
"encoder.upsample.0": "encoder.upsample.projection",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_encoder.layer_norm": "layer_norm",
"w2v_model.layer_norm": "layer_norm",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
@ -91,7 +91,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
-feature_extractor = hf_model.sew.feature_extractor if is_finetuned else hf_model.feature_extractor
+feature_extractor = hf_model.sew_d.feature_extractor if is_finetuned else hf_model.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
@ -108,7 +108,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
for key, mapped_key in MAPPING.items():
mapped_key = "sew_d." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
-if key in name or key.split("w2v_encoder.")[-1] == name.split(".")[0]:
+if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
@ -169,13 +169,13 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
unused_weights.append(full_name)
-def convert_config(model):
+def convert_config(model, is_finetuned):
config = SEWDConfig()
-fs_config = model.cfg
+if is_finetuned:
+    fs_config = model.w2v_encoder.w2v_model.cfg
+else:
+    fs_config = model.cfg
-config.activation_dropout = fs_config.activation_dropout
-config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
-config.attention_dropout = fs_config.attention_dropout
config.conv_bias = fs_config.conv_bias
conv_layers = eval(fs_config.conv_feature_layers)
config.conv_dim = [x[0] for x in conv_layers]
@ -183,19 +183,13 @@ def convert_config(model):
config.conv_stride = [x[2] for x in conv_layers]
config.feat_extract_activation = "gelu"
config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group"
-config.feat_proj_dropout = fs_config.dropout_input
config.final_dropout = 0.0
config.hidden_act = fs_config.activation_fn.name
-config.hidden_dropout = fs_config.dropout
config.hidden_size = fs_config.encoder_embed_dim
config.initializer_range = 0.02
config.intermediate_size = fs_config.encoder_ffn_embed_dim
config.layer_norm_eps = 1e-5
config.layerdrop = fs_config.encoder_layerdrop
-config.mask_feature_length = fs_config.mask_channel_length
-config.mask_feature_prob = fs_config.mask_channel_prob
-config.mask_time_length = fs_config.mask_length
-config.mask_time_prob = fs_config.mask_prob
config.num_attention_heads = fs_config.encoder_attention_heads
config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups
config.num_conv_pos_embeddings = fs_config.conv_pos
@ -211,6 +205,24 @@ def convert_config(model):
config.pos_att_type = tuple(fs_config.pos_att_type.split("|"))
config.norm_rel_ebd = fs_config.norm_rel_ebd
+# take care of any params that are overridden by the Wav2VecCtc model
+if is_finetuned:
+    fs_config = model.cfg
+    config.final_dropout = fs_config.final_dropout
+    config.layerdrop = fs_config.layerdrop
+config.activation_dropout = fs_config.activation_dropout
+config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
+config.attention_dropout = fs_config.attention_dropout
+config.feat_proj_dropout = fs_config.dropout_input
+config.hidden_dropout = fs_config.dropout
+config.mask_feature_length = fs_config.mask_channel_length
+config.mask_feature_prob = fs_config.mask_channel_prob
+config.mask_time_length = fs_config.mask_length
+config.mask_time_prob = fs_config.mask_prob
+config.feature_extractor_type = "Wav2Vec2FeatureExtractor"
+config.tokenizer_class = "Wav2Vec2CTCTokenizer"
return config
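One SEW-D-specific detail: fairseq serializes DeBERTa's relative-attention variants as a `|`-joined string, which the `pos_att_type` line above splits into a tuple. For example (value hypothetical, but in DeBERTa's usual vocabulary):

```python
pos_att_type = tuple("p2c|c2p".split("|"))  # -> ("p2c", "c2p")
```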
@ -232,7 +244,7 @@ def convert_sew_checkpoint(
if config_path is not None:
config = SEWDConfig.from_pretrained(config_path)
else:
-config = convert_config(model[0])
+config = convert_config(model[0], is_finetuned)
model = model[0].eval()
return_attention_mask = True if config.feat_extract_norm == "layer" else False
@ -250,6 +262,8 @@ def convert_sew_checkpoint(
# important change bos & pad token id since CTC symbol is <pad> and
# not <s> as in fairseq
+target_dict.indices[target_dict.bos_word] = target_dict.pad_index
+target_dict.indices[target_dict.pad_word] = target_dict.bos_index
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index

View File

@ -27,12 +27,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
-from ...file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@ -1291,12 +1286,16 @@ SEWD_INPUTS_DOCSTRING = r"""
"The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
SEWD_START_DOCSTRING,
)
+# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD
class SEWDModel(SEWDPreTrainedModel):
def __init__(self, config: SEWDConfig):
super().__init__(config)
self.config = config
self.feature_extractor = SEWDFeatureExtractor(config)
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
-self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+self.project_features = config.conv_dim[-1] != config.hidden_size
+if self.project_features:
+    self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
@ -1353,7 +1352,13 @@ class SEWDModel(SEWDPreTrainedModel):
return hidden_states
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+@add_code_sample_docstrings(
+    processor_class=_PROCESSOR_FOR_DOC,
+    checkpoint=_CHECKPOINT_FOR_DOC,
+    output_type=BaseModelOutput,
+    config_class=_CONFIG_FOR_DOC,
+    modality="audio",
+)
def forward(
self,
input_values,
@ -1363,30 +1368,6 @@ class SEWDModel(SEWDPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Example::
>>> from transformers import Wav2Vec2Processor, SEWDModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
>>> model = SEWDModel.from_pretrained("asapp/sew-tiny-100k")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1397,12 +1378,13 @@ class SEWDModel(SEWDPreTrainedModel):
extract_features = extract_features.transpose(1, 2)
extract_features = self.layer_norm(extract_features)
+if self.project_features:
+    extract_features = self.feature_projection(extract_features)
+hidden_states = self.feature_dropout(extract_features)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
-attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
-hidden_states = self.feature_projection(extract_features)
-hidden_states = self.feature_dropout(hidden_states)
+attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

View File

@ -22,7 +22,7 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import SEWConfig, is_torch_available
-from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, tooslow, torch_device
+from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init
@ -531,27 +531,24 @@ class SEWModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
-@tooslow
def test_inference_ctc_batched(self):
-    # TODO: enable this test once the finetuned models are available
-    model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-100h").to(torch_device)
-    processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-100h", do_lower_case=True)
+    model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h").to(torch_device)
+    processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h", do_lower_case=True)
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
-attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
-logits = model(input_values, attention_mask=attention_mask).logits
+logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"swet covered brian's body trickling into the tightloine closs hat was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

View File

@ -22,7 +22,7 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import SEWDConfig, is_torch_available
-from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, tooslow, torch_device
+from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init
@ -544,27 +544,24 @@ class SEWDModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
-@tooslow
def test_inference_ctc_batched(self):
-    # TODO: enable this test once the finetuned models are available
-    model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-100h").to(torch_device)
-    processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k-ft-100h", do_lower_case=True)
+    model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h").to(torch_device)
+    processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h", do_lower_case=True)
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
-attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
-logits = model(input_values, attention_mask=attention_mask).logits
+logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"swet covered breon's body trickling into the titlowing closs that was the only garmened he war",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)