[Wav2Vec2] Fix convert (#11562)

* push

* small change

* correct other typo
This commit is contained in:
Patrick von Platen 2021-05-03 11:53:30 +02:00 committed by GitHub
parent 623281aa12
commit c448c01f25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -178,9 +178,11 @@ def convert_wav2vec2_checkpoint(
if dict_path:
target_dict = Dictionary.load(dict_path)
config.bos_token_id = target_dict.bos_index
# Important: swap the bos & pad token ids, since the CTC symbol is <pad>
# and not <s> as in fairseq
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.pad_token_id = target_dict.pad_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
@@ -214,9 +216,8 @@ def convert_wav2vec2_checkpoint(
hf_wav2vec = Wav2Vec2Model(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
)
else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])