Minor fixes to XTREME-S (#16193)

* Minor fixes

* Fix vocab union

* Update examples/research_projects/xtreme-s/README.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update README

* unused import

Anton Lozhkov 2022-03-16 17:23:00 +04:00 committed by GitHub
parent 8cc925a241
commit d35e0c6247
2 changed files with 29 additions and 27 deletions

@ -20,7 +20,7 @@ limitations under the License.
The Cross-lingual TRansfer Evaluation of Multilingual Encoders for Speech (XTREME-S) benchmark is a benchmark designed to evaluate speech representations across languages, tasks, domains and data regimes. It covers XX typologically diverse languages and seven downstream tasks grouped in four families: speech recognition, translation, classification and retrieval. The Cross-lingual TRansfer Evaluation of Multilingual Encoders for Speech (XTREME-S) benchmark is a benchmark designed to evaluate speech representations across languages, tasks, domains and data regimes. It covers XX typologically diverse languages and seven downstream tasks grouped in four families: speech recognition, translation, classification and retrieval.
XTREME-S covers speech recognition with BABEL, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (FLoRes) and intent classification (MInds-14) and finally speech retrieval with speech-speech translation data mining (bi-speech retrieval). Each of the tasks covers a subset of the 40 languages included in XTREME-S (shown here with their ISO 639-1 codes): ar, as, ca, cs, cy, da, de, en, en, en, en, es, et, fa, fi, fr, hr, hu, id, it, ja, ka, ko, lo, lt, lv, mn, nl, pl, pt, ro, ru, sk, sl, sv, sw, ta, tl, tr and zh. XTREME-S covers speech recognition with Fleurs, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (Fleurs) and intent classification (MInds-14) and finally speech(-text) retrieval with Fleurs. Each of the tasks covers a subset of the 102 languages included in XTREME-S (shown here with their ISO 3166-1 codes): afr, amh, ara, asm, ast, azj, bel, ben, bos, cat, ceb, zho_simpl, zho_trad, ces, cym, dan, deu, ell, eng, spa, est, fas, ful, fin, tgl, fra, gle, glg, guj, hau, heb, hin, hrv, hun, hye, ind, ibo, isl, ita, jpn, jav, kat, kam, kea, kaz, khm, kan, kor, ckb, kir, ltz, lug, lin, lao, lit, luo, lav, mri, mkd, mal, mon, mar, msa, mlt, mya, nob, npi, nld, nso, nya, oci, orm, ory, pan, pol, pus, por, ron, rus, bul, snd, slk, slv, sna, som, srp, swe, swh, tam, tel, tgk, tha, tur, ukr, umb, urd, uzb, vie, wol, xho, yor and zul.
Paper: `<TODO>` Paper: `<TODO>`
@ -32,16 +32,14 @@ Based on the [`run_xtreme_s.py`](https://github.com/huggingface/transformers/blo
This script can fine-tune any of the pretrained speech models on the [hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition) on the [XTREME-S dataset](https://huggingface.co/datasets/google/xtreme_s) tasks. This script can fine-tune any of the pretrained speech models on the [hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition) on the [XTREME-S dataset](https://huggingface.co/datasets/google/xtreme_s) tasks.
XTREME-S is made up of 7 different task-specific subsets. Here is how to run the script on each of them: XTREME-S is made up of 7 different tasks. Here is how to run the script on each of them:
```bash ```bash
export TASK_NAME=mls.all export TASK_NAME=mls.all
python run_xtreme_s.py \ python run_xtreme_s.py \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--dataset_name="google/xtreme_s" \ --task="${TASK_NAME}" \
--dataset_config_name="${TASK_NAME}" \
--eval_split_name="validation" \
--output_dir="xtreme_s_xlsr_${TASK_NAME}" \ --output_dir="xtreme_s_xlsr_${TASK_NAME}" \
--num_train_epochs=100 \ --num_train_epochs=100 \
--per_device_train_batch_size=32 \ --per_device_train_batch_size=32 \
@ -49,16 +47,16 @@ python run_xtreme_s.py \
--target_column_name="transcription" \ --target_column_name="transcription" \
--save_steps=500 \ --save_steps=500 \
--eval_steps=500 \ --eval_steps=500 \
--freeze_feature_encoder \
--gradient_checkpointing \ --gradient_checkpointing \
--fp16 \ --fp16 \
--group_by_length \ --group_by_length \
--do_train \ --do_train \
--do_eval \ --do_eval \
--do_predict \
--push_to_hub --push_to_hub
``` ```
where `TASK_NAME` can be one of: `mls.all, voxpopuli, covost2.all, fleurs.all, minds14.all`. where `TASK_NAME` can be one of: `mls, voxpopuli, covost2, fleurs-asr, fleurs-lang_id, minds14`.
We get the following results on the test set of the benchmark's datasets. We get the following results on the test set of the benchmark's datasets.
The corresponding training commands for each dataset are given in the sections below: The corresponding training commands for each dataset are given in the sections below:
@ -109,6 +107,7 @@ python -m torch.distributed.launch \
--group_by_length \ --group_by_length \
--do_train \ --do_train \
--do_eval \ --do_eval \
--do_predict \
--metric_for_best_model="wer" \ --metric_for_best_model="wer" \
--greater_is_better=False \ --greater_is_better=False \
--load_best_model_at_end \ --load_best_model_at_end \
@ -152,6 +151,7 @@ python -m torch.distributed.launch \
--group_by_length \ --group_by_length \
--do_train \ --do_train \
--do_eval \ --do_eval \
--do_predict \
--metric_for_best_model="f1" \ --metric_for_best_model="f1" \
--greater_is_better=True \ --greater_is_better=True \
--load_best_model_at_end \ --load_best_model_at_end \

@ -15,7 +15,6 @@
""" Fine-tuning a 🤗 Transformers pretrained speech model on the XTREME-S benchmark tasks""" """ Fine-tuning a 🤗 Transformers pretrained speech model on the XTREME-S benchmark tasks"""
import functools
import json import json
import logging import logging
import os import os
@ -152,8 +151,8 @@ class DataTrainingArguments:
""" """
dataset_name: str = field( dataset_name: str = field(
default="xtreme_s", default="google/xtreme_s",
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"}, metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'google/xtreme_s'"},
) )
task: str = field( task: str = field(
default=None, default=None,
@ -169,21 +168,20 @@ class DataTrainingArguments:
train_split_name: str = field( train_split_name: str = field(
default="train", default="train",
metadata={ metadata={
"help": "The name of the training data set split to use (via the datasets library). " "Defaults to 'train'" "help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train'"
}, },
) )
eval_split_name: str = field( eval_split_name: str = field(
default="validation", default="validation",
metadata={ metadata={
"help": "The name of the evaluation data set split to use (via the datasets library). " "help": "The name of the evaluation dataset split to use (via the datasets library). "
"Defaults to 'validation'" "Defaults to 'validation'"
}, },
) )
predict_split_name: str = field( predict_split_name: str = field(
default="test", default="test",
metadata={ metadata={
"help": "The name of the prediction data set split to use (via the datasets library). " "help": "The name of the prediction dataset split to use (via the datasets library). " "Defaults to 'test'"
"Defaults to 'test'"
}, },
) )
audio_column_name: str = field( audio_column_name: str = field(
@ -191,10 +189,10 @@ class DataTrainingArguments:
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
) )
target_column_name: str = field( target_column_name: str = field(
default="transcription", default=None,
metadata={ metadata={
"help": "The name of the dataset column containing the target data " "help": "The name of the dataset column containing the target data "
"(transcription/translation/label). Defaults to 'transcription'" "(transcription/translation/label). If None, the name will be inferred from the task. Defaults to None."
}, },
) )
overwrite_cache: bool = field( overwrite_cache: bool = field(
@ -348,8 +346,10 @@ def create_vocabulary_from_data(
) )
# take union of all unique characters in each dataset # take union of all unique characters in each dataset
vocab_set = functools.reduce( vocab_set = (
lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values() (set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set())
| (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set())
| (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
) )
vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))} vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
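Note on the vocab union change: the commit message does not say why `functools.reduce` was dropped, so the following is one plausible reading, sketched under stated assumptions. The lambda returns a plain `set`, so on the second reduction step the accumulator is a set and can no longer be indexed with `["vocab"]`. The `vocabs` dictionary below is a hypothetical stand-in built from nested dicts and lists (the real script stores dataset objects), and the two helper function names are invented for illustration.

```python
import functools

# Hypothetical stand-in for the script's `vocabs` mapping: each split maps to an
# object whose "vocab" entry is a one-element list holding that split's characters.
vocabs = {
    "train": {"vocab": [["a", "b", "c"]]},
    "eval": {"vocab": [["b", "d"]]},
    "predict": {"vocab": [["e"]]},
}


def union_with_reduce(vocabs):
    # Old approach: the lambda returns a set, so from the second step onward
    # `vocab_1` is a set and `vocab_1["vocab"]` raises TypeError.
    return functools.reduce(
        lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]),
        vocabs.values(),
    )


def union_per_split(vocabs):
    # New approach: build one set per split that is present and union them,
    # which works for one, two or three splits.
    return (
        (set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set())
        | (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set())
        | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
    )


try:
    union_with_reduce(vocabs)
except TypeError as err:
    print(f"reduce-based union failed with three splits: {err}")

print(sorted(union_per_split(vocabs)))
```

With exactly two splits the reduce-based version happens to work, which may be why the problem only shows up once `--do_predict` adds a third split.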
@ -434,7 +434,10 @@ def main():
" for multi-lingual fine-tuning." " for multi-lingual fine-tuning."
) )
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name] if data_args.target_column_name is None:
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
else:
target_column_name = data_args.target_column_name
# here we differentiate between tasks with text as the target and classification tasks # here we differentiate between tasks with text as the target and classification tasks
is_text_target = target_column_name in ("transcription", "translation") is_text_target = target_column_name in ("transcription", "translation")
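Note on the target column change: the fallback above can be sketched in isolation as follows. This is a minimal illustration, not the script's code. The contents of `TASK_TO_TARGET_COLUMN_NAME` are not shown in this diff, so the mapping below is a hypothetical example, and `resolve_target_column` is a helper name invented here.

```python
from typing import Optional

# Hypothetical mapping for illustration only; the real table lives in the script
# and its entries are not visible in this diff.
TASK_TO_TARGET_COLUMN_NAME = {
    "mls": "transcription",
    "covost2": "translation",
    "minds14": "intent_class",
}


def resolve_target_column(task_name: str, override: Optional[str] = None) -> str:
    # An explicit --target_column_name wins; otherwise infer it from the task.
    if override is None:
        return TASK_TO_TARGET_COLUMN_NAME[task_name]
    return override


def is_text_target(target_column_name: str) -> bool:
    # Text targets go through the tokenizer; anything else is treated as a class label.
    return target_column_name in ("transcription", "translation")


inferred = resolve_target_column("mls")
explicit = resolve_target_column("mls", override="translation")
print(inferred, is_text_target(inferred))
print(explicit, is_text_target(explicit))
print(is_text_target(resolve_target_column("minds14")))
```

Resolving the name once into a local variable, as the later hunks do, keeps the column check, the label list and the preprocessing step consistent even when `data_args.target_column_name` is left as `None`.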
@ -457,9 +460,9 @@ def main():
f"{', '.join(raw_datasets['train'].column_names)}." f"{', '.join(raw_datasets['train'].column_names)}."
) )
if data_args.target_column_name not in raw_datasets["train"].column_names: if target_column_name not in raw_datasets["train"].column_names:
raise ValueError( raise ValueError(
f"--target_column_name {data_args.target_column_name} not found in dataset '{data_args.dataset_name}'. " f"--target_column_name {target_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--target_column_name` to the correct text column - one of " "Make sure to set `--target_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}." f"{', '.join(raw_datasets['train'].column_names)}."
) )
@ -468,7 +471,7 @@ def main():
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
if not is_text_target: if not is_text_target:
label_list = raw_datasets["train"].features[data_args.target_column_name].names label_list = raw_datasets["train"].features[target_column_name].names
num_labels = len(label_list) num_labels = len(label_list)
if training_args.do_eval: if training_args.do_eval:
@ -684,7 +687,7 @@ def main():
if is_text_target: if is_text_target:
batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
else: else:
batch["labels"] = batch[data_args.target_column_name] batch["labels"] = batch[target_column_name]
return batch return batch
with training_args.main_process_first(desc="dataset map preprocessing"): with training_args.main_process_first(desc="dataset map preprocessing"):
@ -809,10 +812,10 @@ def main():
trainer.save_metrics("train", metrics) trainer.save_metrics("train", metrics)
trainer.save_state() trainer.save_state()
# Evaluation # Evaluation on the test set
results = {} results = {}
if training_args.do_predict: if training_args.do_predict:
logger.info("*** Predicte ***") logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***")
metrics = trainer.evaluate(vectorized_datasets["predict"]) metrics = trainer.evaluate(vectorized_datasets["predict"])
max_predict_samples = ( max_predict_samples = (
data_args.max_predict_samples data_args.max_predict_samples
@ -831,9 +834,8 @@ def main():
"tags": [task_name, data_args.dataset_name], "tags": [task_name, data_args.dataset_name],
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}", "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}", "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
"language": data_args.language,
} }
if "common_voice" in data_args.dataset_name:
kwargs["language"] = config_name
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub(**kwargs) trainer.push_to_hub(**kwargs)