Adapt Wav2Vec2 conversion for MMS lang identification (#24234)

* Add conversion for mms lid

* make style
This commit is contained in:
Patrick von Platen 2023-06-14 16:02:36 +02:00 committed by GitHub
parent 4626df5077
commit c4fec38bc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -32,6 +32,7 @@ from transformers import (
Wav2Vec2Processor,
logging,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification
logging.set_verbosity_info()
@ -57,6 +58,8 @@ MAPPING = {
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
"pooling_layer.linear": "projector",
"pooling_layer.projection": "classifier",
}
TOP_LEVEL_KEYS = [
"lm_head",
@ -64,9 +67,24 @@ TOP_LEVEL_KEYS = [
"quantizer.codevectors",
"project_q",
"project_hid",
"projector",
"classifier",
]
def read_txt_into_dict(filename):
result = {}
with open(filename, "r") as file:
for line_number, line in enumerate(file):
line = line.strip()
if line:
words = line.split()
key = line_number
value = words[0]
result[key] = value
return result
def set_recursively(key, value, full_name, weight_type, hf_pointer):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
@ -240,7 +258,7 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
@torch.no_grad()
def convert_wav2vec2_checkpoint(
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
):
"""
Copy/paste/tweak model's weights to transformers design.
@ -250,7 +268,20 @@ def convert_wav2vec2_checkpoint(
else:
config = Wav2Vec2Config()
if is_finetuned:
if is_seq_class:
id2label = read_txt_into_dict(dict_path)
config.id2label = id2label
hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=True,
)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
elif is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
@ -296,7 +327,7 @@ def convert_wav2vec2_checkpoint(
else:
hf_wav2vec = Wav2Vec2ForPreTraining(config)
if is_finetuned:
if is_finetuned or is_seq_class:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
)
@ -322,7 +353,19 @@ if __name__ == "__main__":
parser.add_argument(
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
)
args = parser.parse_args()
convert_wav2vec2_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
parser.add_argument(
"--is_seq_class",
action="store_true",
help="Whether the model to convert is a fine-tuned sequence classification model or not",
)
args = parser.parse_args()
is_finetuned = not args.not_finetuned and not args.is_seq_class
convert_wav2vec2_checkpoint(
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.config_path,
args.dict_path,
is_finetuned,
args.is_seq_class,
)