Minor fixes to XTREME-S (#16193)

* Minor fixes

* Fix vocab union

* Update examples/research_projects/xtreme-s/README.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update README

* unused import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Anton Lozhkov 2022-03-16 17:23:00 +04:00 committed by GitHub
parent 8cc925a241
commit d35e0c6247
2 changed files with 29 additions and 27 deletions

examples/research_projects/xtreme-s/README.md

@@ -20,7 +20,7 @@ limitations under the License.
The Cross-lingual TRansfer Evaluation of Multilingual Encoders for Speech (XTREME-S) benchmark is a benchmark designed to evaluate speech representations across languages, tasks, domains and data regimes. It covers XX typologically diverse languages and seven downstream tasks grouped in four families: speech recognition, translation, classification and retrieval.
XTREME-S covers speech recognition with BABEL, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (FLoRes) and intent classification (MInds-14) and finally speech retrieval with speech-speech translation data mining (bi-speech retrieval). Each of the tasks covers a subset of the 40 languages included in XTREME-S (shown here with their ISO 639-1 codes): ar, as, ca, cs, cy, da, de, en, en, en, en, es, et, fa, fi, fr, hr, hu, id, it, ja, ka, ko, lo, lt, lv, mn, nl, pl, pt, ro, ru, sk, sl, sv, sw, ta, tl, tr and zh.
XTREME-S covers speech recognition with Fleurs, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (Fleurs) and intent classification (MInds-14) and finally speech(-text) retrieval with Fleurs. Each of the tasks covers a subset of the 102 languages included in XTREME-S (shown here with their ISO 3166-1 codes): afr, amh, ara, asm, ast, azj, bel, ben, bos, cat, ceb, zho_simpl, zho_trad, ces, cym, dan, deu, ell, eng, spa, est, fas, ful, fin, tgl, fra, gle, glg, guj, hau, heb, hin, hrv, hun, hye, ind, ibo, isl, ita, jpn, jav, kat, kam, kea, kaz, khm, kan, kor, ckb, kir, ltz, lug, lin, lao, lit, luo, lav, mri, mkd, mal, mon, mar, msa, mlt, mya, nob, npi, nld, nso, nya, oci, orm, ory, pan, pol, pus, por, ron, rus, bul, snd, slk, slv, sna, som, srp, swe, swh, tam, tel, tgk, tha, tur, ukr, umb, urd, uzb, vie, wol, xho, yor and zul.
Paper: `<TODO>`
@@ -32,16 +32,14 @@ Based on the [`run_xtreme_s.py`](https://github.com/huggingface/transformers/blo
This script can fine-tune any of the pretrained speech models on the [hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition) on the [XTREME-S dataset](https://huggingface.co/datasets/google/xtreme_s) tasks.
XTREME-S is made up of 7 different task-specific subsets. Here is how to run the script on each of them:
XTREME-S is made up of 7 different tasks. Here is how to run the script on each of them:
```bash
export TASK_NAME=mls.all
python run_xtreme_s.py \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--dataset_name="google/xtreme_s" \
--dataset_config_name="${TASK_NAME}" \
--eval_split_name="validation" \
--task="${TASK_NAME}" \
--output_dir="xtreme_s_xlsr_${TASK_NAME}" \
--num_train_epochs=100 \
--per_device_train_batch_size=32 \
@@ -49,16 +47,16 @@ python run_xtreme_s.py \
--target_column_name="transcription" \
--save_steps=500 \
--eval_steps=500 \
--freeze_feature_encoder \
--gradient_checkpointing \
--fp16 \
--group_by_length \
--do_train \
--do_eval \
--do_predict \
--push_to_hub
```
where `TASK_NAME` can be one of: `mls.all, voxpopuli, covost2.all, fleurs.all, minds14.all`.
where `TASK_NAME` can be one of: `mls, voxpopuli, covost2, fleurs-asr, fleurs-lang_id, minds14`.
We get the following results on the test set of the benchmark's datasets.
The corresponding training commands for each dataset are given in the sections below:
@@ -109,6 +107,7 @@ python -m torch.distributed.launch \
--group_by_length \
--do_train \
--do_eval \
--do_predict \
--metric_for_best_model="wer" \
--greater_is_better=False \
--load_best_model_at_end \
@@ -152,6 +151,7 @@ python -m torch.distributed.launch \
--group_by_length \
--do_train \
--do_eval \
--do_predict \
--metric_for_best_model="f1" \
--greater_is_better=True \
--load_best_model_at_end \

examples/research_projects/xtreme-s/run_xtreme_s.py

@@ -15,7 +15,6 @@
""" Fine-tuning a 🤗 Transformers pretrained speech model on the XTREME-S benchmark tasks"""
import functools
import json
import logging
import os
@@ -152,8 +151,8 @@ class DataTrainingArguments:
"""
dataset_name: str = field(
default="xtreme_s",
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
default="google/xtreme_s",
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'google/xtreme_s'"},
)
task: str = field(
default=None,
@@ -169,21 +168,20 @@ class DataTrainingArguments:
train_split_name: str = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). " "Defaults to 'train'"
"help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train'"
},
)
eval_split_name: str = field(
default="validation",
metadata={
"help": "The name of the evaluation data set split to use (via the datasets library). "
"help": "The name of the evaluation dataset split to use (via the datasets library). "
"Defaults to 'validation'"
},
)
predict_split_name: str = field(
default="test",
metadata={
"help": "The name of the prediction data set split to use (via the datasets library). "
"Defaults to 'test'"
"help": "The name of the prediction dataset split to use (via the datasets library). " "Defaults to 'test'"
},
)
audio_column_name: str = field(
@@ -191,10 +189,10 @@
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
target_column_name: str = field(
default="transcription",
default=None,
metadata={
"help": "The name of the dataset column containing the target data "
"(transcription/translation/label). Defaults to 'transcription'"
"(transcription/translation/label). If None, the name will be inferred from the task. Defaults to None."
},
)
overwrite_cache: bool = field(
@@ -348,8 +346,10 @@ def create_vocabulary_from_data(
)
# take union of all unique characters in each dataset
vocab_set = functools.reduce(
lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
vocab_set = (
(set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set())
| (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set())
| (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
)
vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
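This is the "Fix vocab union" part of the commit: the old `functools.reduce` call only behaves correctly for exactly two splits. With three splits the accumulator is already a `set` on the second merge (and sets are not subscriptable), and with a single split `reduce` returns the mapped dataset itself rather than a set. A minimal, runnable sketch with plain dicts standing in for the mapped datasets:

```python
import functools

# Toy stand-ins for the mapped datasets: each split exposes its unique
# characters under vocabs[split]["vocab"][0].
vocabs = {
    "train": {"vocab": [["a", "b", "c"]]},
    "eval": {"vocab": [["b", "c", "d"]]},
    "predict": {"vocab": [["d", "e"]]},
}

# Old approach: fails as soon as a third split is present.
try:
    functools.reduce(
        lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]),
        vocabs.values(),
    )
except TypeError as err:
    print(f"reduce fails: {err}")  # 'set' object is not subscriptable

# New approach: an explicit union over whichever splits are present.
vocab_set = (
    (set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set())
    | (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set())
    | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
)
print(sorted(vocab_set))  # ['a', 'b', 'c', 'd', 'e']
```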
@@ -434,7 +434,10 @@ def main():
" for multi-lingual fine-tuning."
)
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
if data_args.target_column_name is None:
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
else:
target_column_name = data_args.target_column_name
# here we differentiate between tasks with text as the target and classification tasks
is_text_target = target_column_name in ("transcription", "translation")
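With `--target_column_name` now defaulting to `None`, the effective column is resolved once and reused everywhere below. A minimal sketch of that resolution (the mapping values here are assumptions for illustration; the real `TASK_TO_TARGET_COLUMN_NAME` table is defined in `run_xtreme_s.py`):

```python
from typing import Optional

# Illustrative only: the real mapping lives in run_xtreme_s.py.
TASK_TO_TARGET_COLUMN_NAME = {
    "mls": "transcription",
    "voxpopuli": "transcription",
    "fleurs-asr": "transcription",
    "covost2": "translation",
    "minds14": "intent_class",
    "fleurs-lang_id": "lang_id",
}

def resolve_target_column(task_name: str, override: Optional[str]) -> str:
    # An explicit --target_column_name wins; otherwise fall back to the
    # task's default, as the new branch in main() does.
    if override is not None:
        return override
    return TASK_TO_TARGET_COLUMN_NAME[task_name]

target_column_name = resolve_target_column("covost2", None)
# Text targets go through the tokenizer; anything else is a class label.
is_text_target = target_column_name in ("transcription", "translation")
print(target_column_name, is_text_target)  # translation True
```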
@@ -457,9 +460,9 @@ def main():
f"{', '.join(raw_datasets['train'].column_names)}."
)
if data_args.target_column_name not in raw_datasets["train"].column_names:
if target_column_name not in raw_datasets["train"].column_names:
raise ValueError(
f"--target_column_name {data_args.target_column_name} not found in dataset '{data_args.dataset_name}'. "
f"--target_column_name {target_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--target_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
@@ -468,7 +471,7 @@ def main():
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
if not is_text_target:
label_list = raw_datasets["train"].features[data_args.target_column_name].names
label_list = raw_datasets["train"].features[target_column_name].names
num_labels = len(label_list)
if training_args.do_eval:
@@ -684,7 +687,7 @@ def main():
if is_text_target:
batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
else:
batch["labels"] = batch[data_args.target_column_name]
batch["labels"] = batch[target_column_name]
return batch
with training_args.main_process_first(desc="dataset map preprocessing"):
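Note that preprocessing also has to read from the resolved local rather than `data_args.target_column_name`, which may now be `None`. A toy version of the label branch above, with a hypothetical character-level tokenizer standing in for the script's real one:

```python
from types import SimpleNamespace

def toy_tokenizer(text):
    # Hypothetical stand-in for the script's tokenizer.
    return SimpleNamespace(input_ids=[ord(c) for c in text])

def attach_labels(batch, tokenizer, target_column_name, is_text_target):
    # Mirrors the branch above: text tasks tokenize the target text,
    # classification tasks pass the integer class id straight through.
    if is_text_target:
        batch["labels"] = tokenizer(batch["target_text"]).input_ids
    else:
        batch["labels"] = batch[target_column_name]
    return batch

asr = attach_labels({"target_text": "hallo"}, toy_tokenizer, "transcription", True)
cls = attach_labels({"intent_class": 3}, toy_tokenizer, "intent_class", False)
print(asr["labels"], cls["labels"])  # [104, 97, 108, 108, 111] 3
```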
@@ -809,10 +812,10 @@ def main():
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
# Evaluation on the test set
results = {}
if training_args.do_predict:
logger.info("*** Predicte ***")
logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***")
metrics = trainer.evaluate(vectorized_datasets["predict"])
max_predict_samples = (
data_args.max_predict_samples
@@ -831,9 +834,8 @@ def main():
"tags": [task_name, data_args.dataset_name],
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
"language": data_args.language,
}
if "common_voice" in data_args.dataset_name:
kwargs["language"] = config_name
if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)