mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-18 20:18:24 +06:00
[XLSR-Wav2Vec2] Add multi-lingual Wav2Vec2 models (#10648)
* add conversion script * add wav2vec2 xslr models * finish * Update docs/source/model_doc/xlsr_wav2vec2.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
63c295ac05
commit
602d63f05c
@ -237,6 +237,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
|||||||
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||||
1. **[XLM-RoBERTa](https://huggingface.co/transformers/model_doc/xlmroberta.html)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
1. **[XLM-RoBERTa](https://huggingface.co/transformers/model_doc/xlmroberta.html)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||||
1. **[XLNet](https://huggingface.co/transformers/model_doc/xlnet.html)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
1. **[XLNet](https://huggingface.co/transformers/model_doc/xlnet.html)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
|
1. **[XLSR-Wav2Vec2](https://huggingface.co/transformers/model_doc/xlsr_wav2vec2.html)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||||
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||||
|
|
||||||
To check if each model has an implementation in PyTorch/TensorFlow/Flax or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/transformers/index.html#bigtable)
|
To check if each model has an implementation in PyTorch/TensorFlow/Flax or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/transformers/index.html#bigtable)
|
||||||
|
@ -221,6 +221,9 @@ and conversion utilities for the following models:
|
|||||||
46. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
46. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
|
47. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||||
|
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
|
||||||
|
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||||
|
|
||||||
|
|
||||||
.. _bigtable:
|
.. _bigtable:
|
||||||
@ -451,6 +454,7 @@ TensorFlow and/or Flax.
|
|||||||
model_doc/xlmprophetnet
|
model_doc/xlmprophetnet
|
||||||
model_doc/xlmroberta
|
model_doc/xlmroberta
|
||||||
model_doc/xlnet
|
model_doc/xlnet
|
||||||
|
model_doc/xlsr_wav2vec2
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
45
docs/source/model_doc/xlsr_wav2vec2.rst
Normal file
45
docs/source/model_doc/xlsr_wav2vec2.rst
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
..
|
||||||
|
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
XLSR-Wav2Vec2
|
||||||
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Overview
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
The XLSR-Wav2Vec2 model was proposed in `Unsupervised Cross-Lingual Representation Learning For Speech Recognition
|
||||||
|
<https://arxiv.org/abs/2006.13979>`__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael
|
||||||
|
Auli.
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*This paper presents XLSR which learns cross-lingual speech representations by pretraining a single model from the raw
|
||||||
|
waveform of speech in multiple languages. We build on wav2vec 2.0 which is trained by solving a contrastive task over
|
||||||
|
masked latent speech representations and jointly learns a quantization of the latents shared across languages. The
|
||||||
|
resulting model is fine-tuned on labeled data and experiments show that cross-lingual pretraining significantly
|
||||||
|
outperforms monolingual pretraining. On the CommonVoice benchmark, XLSR shows a relative phoneme error rate reduction
|
||||||
|
of 72% compared to the best known results. On BABEL, our approach improves word error rate by 16% relative compared to
|
||||||
|
a comparable system. Our approach enables a single multilingual speech recognition model which is competitive to strong
|
||||||
|
individual models. Analysis shows that the latent discrete speech representations are shared across languages with
|
||||||
|
increased sharing for related languages. We hope to catalyze research in low-resource speech understanding by releasing
|
||||||
|
XLSR-53, a large model pretrained in 53 languages.*
|
||||||
|
|
||||||
|
Tips:
|
||||||
|
|
||||||
|
- XLSR-Wav2Vec2 is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
|
||||||
|
- XLSR-Wav2Vec2 model was trained using connectionist temporal classification (CTC) so the model output has to be
|
||||||
|
decoded using :class:`~transformers.Wav2Vec2CTCTokenizer`.
|
||||||
|
|
||||||
|
XLSR-Wav2Vec2's architecture is based on the Wav2Vec2 model, so one can refer to :doc:`Wav2Vec2's documentation page
|
||||||
|
<wav2vec2>`.
|
||||||
|
|
||||||
|
The original code can be found `here <https://github.com/pytorch/fairseq/tree/master/fairseq/models/wav2vec>`__.
|
@ -90,7 +90,8 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
|||||||
else:
|
else:
|
||||||
for key, mapped_key in MAPPING.items():
|
for key, mapped_key in MAPPING.items():
|
||||||
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
||||||
if key in name:
|
|
||||||
|
if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned):
|
||||||
is_used = True
|
is_used = True
|
||||||
if "*" in mapped_key:
|
if "*" in mapped_key:
|
||||||
layer_index = name.split(key)[0].split(".")[-2]
|
layer_index = name.split(key)[0].split(".")[-2]
|
||||||
@ -110,7 +111,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
|||||||
if not is_used:
|
if not is_used:
|
||||||
unused_weights.append(name)
|
unused_weights.append(name)
|
||||||
|
|
||||||
logger.info("Unused weights", unused_weights)
|
logger.warn(f"Unused weights: {unused_weights}")
|
||||||
|
|
||||||
|
|
||||||
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
||||||
|
Loading…
Reference in New Issue
Block a user