mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[Wav2Vec2] Fix convert (#11562)
* push * small change * correct other typo
This commit is contained in:
parent
623281aa12
commit
c448c01f25
@ -178,9 +178,11 @@ def convert_wav2vec2_checkpoint(
|
||||
if dict_path:
|
||||
target_dict = Dictionary.load(dict_path)
|
||||
|
||||
config.bos_token_id = target_dict.bos_index
|
||||
# important change bos & pad token id since CTC symbol is <pad> and
|
||||
# not <s> as in fairseq
|
||||
config.bos_token_id = target_dict.pad_index
|
||||
config.pad_token_id = target_dict.bos_index
|
||||
config.eos_token_id = target_dict.eos_index
|
||||
config.pad_token_id = target_dict.pad_index
|
||||
config.vocab_size = len(target_dict.symbols)
|
||||
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
@ -214,9 +216,8 @@ def convert_wav2vec2_checkpoint(
|
||||
hf_wav2vec = Wav2Vec2Model(config)
|
||||
|
||||
if is_finetuned:
|
||||
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[checkpoint_path], arg_overrides={"data": dict_path}
|
||||
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
|
||||
)
|
||||
else:
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
|
||||
|
Loading…
Reference in New Issue
Block a user