Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Adapt Wav2Vec2 conversion for MMS lang identification (#24234)
* Add conversion for mms lid
* make style
This commit is contained in:
parent 4626df5077
commit c4fec38bc7
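
For context, a minimal sketch of how the updated converter would be exercised for an MMS language-identification (LID) checkpoint. The script name, checkpoint file, and paths below are hypothetical placeholders; only the function and its keyword arguments come from this commit:

    # Hypothetical CLI invocation (script name and paths are placeholders):
    #   python convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py \
    #       --checkpoint_path mms_lid.pt --pytorch_dump_folder_path ./mms-lid \
    #       --dict_path language_list.txt --is_seq_class
    # Per the __main__ block in the diff below, that is equivalent to:
    convert_wav2vec2_checkpoint(
        checkpoint_path="mms_lid.pt",          # fairseq MMS LID checkpoint (placeholder)
        pytorch_dump_folder_path="./mms-lid",  # output folder for the converted model
        config_path=None,
        dict_path="language_list.txt",         # language list; first token per line is the label
        is_finetuned=False,                    # --is_seq_class disables the CTC fine-tuned path
        is_seq_class=True,
    )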
@@ -32,6 +32,7 @@ from transformers import (
     Wav2Vec2Processor,
     logging,
 )
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification
 
 
 logging.set_verbosity_info()
@@ -57,6 +58,8 @@ MAPPING = {
     "final_proj": "project_hid",
     "w2v_encoder.proj": "lm_head",
     "mask_emb": "masked_spec_embed",
+    "pooling_layer.linear": "projector",
+    "pooling_layer.projection": "classifier",
 }
 TOP_LEVEL_KEYS = [
     "lm_head",
@@ -64,9 +67,24 @@ TOP_LEVEL_KEYS = [
     "quantizer.codevectors",
     "project_q",
     "project_hid",
+    "projector",
+    "classifier",
 ]
 
 
+def read_txt_into_dict(filename):
+    result = {}
+    with open(filename, "r") as file:
+        for line_number, line in enumerate(file):
+            line = line.strip()
+            if line:
+                words = line.split()
+                key = line_number
+                value = words[0]
+                result[key] = value
+    return result
+
+
 def set_recursively(key, value, full_name, weight_type, hf_pointer):
     for attribute in key.split("."):
         hf_pointer = getattr(hf_pointer, attribute)
@@ -240,7 +258,7 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
 
 @torch.no_grad()
 def convert_wav2vec2_checkpoint(
-    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
+    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
 ):
     """
     Copy/paste/tweak model's weights to transformers design.
@@ -250,7 +268,20 @@ def convert_wav2vec2_checkpoint
     else:
         config = Wav2Vec2Config()
 
-    if is_finetuned:
+    if is_seq_class:
+        id2label = read_txt_into_dict(dict_path)
+        config.id2label = id2label
+        hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
+        feature_extractor = Wav2Vec2FeatureExtractor(
+            feature_size=1,
+            sampling_rate=16000,
+            padding_value=0,
+            do_normalize=True,
+            return_attention_mask=True,
+        )
+        feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+    elif is_finetuned:
         if dict_path:
             target_dict = Dictionary.load(dict_path)
 
@@ -296,7 +327,7 @@ def convert_wav2vec2_checkpoint
     else:
         hf_wav2vec = Wav2Vec2ForPreTraining(config)
 
-    if is_finetuned:
+    if is_finetuned or is_seq_class:
         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
             [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
         )
@@ -322,7 +353,19 @@ if __name__ == "__main__":
     parser.add_argument(
         "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
     )
-    args = parser.parse_args()
-    convert_wav2vec2_checkpoint(
-        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
+    parser.add_argument(
+        "--is_seq_class",
+        action="store_true",
+        help="Whether the model to convert is a fine-tuned sequence classification model or not",
+    )
+    args = parser.parse_args()
+
+    is_finetuned = not args.not_finetuned and not args.is_seq_class
+    convert_wav2vec2_checkpoint(
+        args.checkpoint_path,
+        args.pytorch_dump_folder_path,
+        args.config_path,
+        args.dict_path,
+        is_finetuned,
+        args.is_seq_class,
    )
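
For illustration, the new read_txt_into_dict helper builds the id2label mapping by keeping the first whitespace-separated token of every non-empty line, keyed by the zero-based line number. A self-contained sketch of that logic on an invented sample (note that a blank line is skipped but still consumes a line number, leaving a gap in the keys):

    # Mirrors read_txt_into_dict on an in-memory sample (sample data is invented).
    sample_lines = ["eng 1000000", "fra 500000", "", "deu 400000"]

    id2label = {}
    for line_number, line in enumerate(sample_lines):
        line = line.strip()
        if line:
            words = line.split()
            id2label[line_number] = words[0]  # first token is kept as the label

    print(id2label)  # {0: 'eng', 1: 'fra', 3: 'deu'}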
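
Once converted, the dumped folder should load with the standard transformers API. A hedged usage sketch, reusing the "./mms-lid" placeholder from above:

    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

    # Load the converted LID model; config.id2label maps class indices to language codes.
    model = Wav2Vec2ForSequenceClassification.from_pretrained("./mms-lid")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./mms-lid")
    print(model.config.id2label)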