Mirror of https://github.com/huggingface/transformers.git
Minor fixes to XTREME-S (#16193)
* Minor fixes

* Fix vocab union

* Update examples/research_projects/xtreme-s/README.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update README

* unused import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 8cc925a241
commit d35e0c6247
examples/research_projects/xtreme-s/README.md

@@ -20,7 +20,7 @@ limitations under the License.
The Cross-lingual TRansfer Evaluation of Multilingual Encoders for Speech (XTREME-S) benchmark is designed to evaluate speech representations across languages, tasks, domains and data regimes. It covers 102 typologically diverse languages and seven downstream tasks grouped in four families: speech recognition, translation, classification and retrieval.
- XTREME-S covers speech recognition with BABEL, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (FLoRes) and intent classification (MInds-14) and finally speech retrieval with speech-speech translation data mining (bi-speech retrieval). Each of the tasks covers a subset of the 40 languages included in XTREME-S (shown here with their ISO 639-1 codes): ar, as, ca, cs, cy, da, de, en, es, et, fa, fi, fr, hr, hu, id, it, ja, ka, ko, lo, lt, lv, mn, nl, pl, pt, ro, ru, sk, sl, sv, sw, ta, tl, tr and zh.
+ XTREME-S covers speech recognition with Fleurs, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (Fleurs) and intent classification (MInds-14) and finally speech(-text) retrieval with Fleurs. Each of the tasks covers a subset of the 102 languages included in XTREME-S (shown here with their ISO 639-3 codes): afr, amh, ara, asm, ast, azj, bel, ben, bos, cat, ceb, zho_simpl, zho_trad, ces, cym, dan, deu, ell, eng, spa, est, fas, ful, fin, tgl, fra, gle, glg, guj, hau, heb, hin, hrv, hun, hye, ind, ibo, isl, ita, jpn, jav, kat, kam, kea, kaz, khm, kan, kor, ckb, kir, ltz, lug, lin, lao, lit, luo, lav, mri, mkd, mal, mon, mar, msa, mlt, mya, nob, npi, nld, nso, nya, oci, orm, ory, pan, pol, pus, por, ron, rus, bul, snd, slk, slv, sna, som, srp, swe, swh, tam, tel, tgk, tha, tur, ukr, umb, urd, uzb, vie, wol, xho, yor and zul.
Paper: `<TODO>`
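
For orientation, the benchmark data referenced above can be loaded directly with the `datasets` library. A minimal sketch, assuming the `google/xtreme_s` dataset and the `mls.all` config name used in the training command below:

```python
from datasets import load_dataset

# Stream the Multilingual LibriSpeech subset of XTREME-S (config name as in
# the README's TASK_NAME example below) to avoid a full download.
mls = load_dataset("google/xtreme_s", "mls.all", split="validation", streaming=True)

sample = next(iter(mls))
print(sample["audio"]["sampling_rate"])  # XTREME-S ships 16 kHz audio
print(sample["transcription"])           # ASR target text, cf. --target_column_name
```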
@@ -32,16 +32,14 @@ Based on the [`run_xtreme_s.py`](https://github.com/huggingface/transformers/blo
This script can fine-tune any of the pretrained speech models on the [hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition) on the tasks of the [XTREME-S dataset](https://huggingface.co/datasets/google/xtreme_s).
- XTREME-S is made up of 7 different task-specific subsets. Here is how to run the script on each of them:
+ XTREME-S is made up of 7 different tasks. Here is how to run the script on each of them:
```bash
export TASK_NAME=mls.all

python run_xtreme_s.py \
    --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
    --dataset_name="google/xtreme_s" \
    --dataset_config_name="${TASK_NAME}" \
    --eval_split_name="validation" \
    --task="${TASK_NAME}" \
    --output_dir="xtreme_s_xlsr_${TASK_NAME}" \
    --num_train_epochs=100 \
    --per_device_train_batch_size=32 \
@@ -49,16 +47,16 @@ python run_xtreme_s.py \
    --target_column_name="transcription" \
    --save_steps=500 \
    --eval_steps=500 \
    --freeze_feature_encoder \
    --gradient_checkpointing \
    --fp16 \
    --group_by_length \
    --do_train \
    --do_eval \
    --do_predict \
    --push_to_hub
```
- where `TASK_NAME` can be one of: `mls.all, voxpopuli, covost2.all, fleurs.all, minds14.all`.
+ where `TASK_NAME` can be one of: `mls, voxpopuli, covost2, fleurs-asr, fleurs-lang_id, minds14`.
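
To sweep several tasks with one launcher, the script can also be driven from Python; a sketch (flag list deliberately minimal, extend it with the remaining options from the example above):

```python
import subprocess

# Run one fine-tuning job per XTREME-S task; add the remaining flags from
# the example command above (epochs, batch size, save/eval steps, ...).
for task in ["mls", "voxpopuli", "covost2", "fleurs-asr", "fleurs-lang_id", "minds14"]:
    subprocess.run(
        [
            "python", "run_xtreme_s.py",
            "--model_name_or_path=facebook/wav2vec2-xls-r-300m",
            "--dataset_name=google/xtreme_s",
            f"--task={task}",
            f"--output_dir=xtreme_s_xlsr_{task}",
            "--do_train", "--do_eval", "--do_predict",
        ],
        check=True,
    )
```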
We get the following results on the test set of the benchmark's datasets.
The corresponding training commands for each dataset are given in the sections below:
@@ -109,6 +107,7 @@ python -m torch.distributed.launch \
    --group_by_length \
    --do_train \
    --do_eval \
    --do_predict \
    --metric_for_best_model="wer" \
    --greater_is_better=False \
    --load_best_model_at_end \
@@ -152,6 +151,7 @@ python -m torch.distributed.launch \
    --group_by_length \
    --do_train \
    --do_eval \
    --do_predict \
    --metric_for_best_model="f1" \
    --greater_is_better=True \
    --load_best_model_at_end \
examples/research_projects/xtreme-s/run_xtreme_s.py

@@ -15,7 +15,6 @@
""" Fine-tuning a 🤗 Transformers pretrained speech model on the XTREME-S benchmark tasks"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -152,8 +151,8 @@ class DataTrainingArguments:
"""
|
||||
|
||||
dataset_name: str = field(
|
||||
default="xtreme_s",
|
||||
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
|
||||
default="google/xtreme_s",
|
||||
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'google/xtreme_s'"},
|
||||
)
|
||||
    task: str = field(
        default=None,
@@ -169,21 +168,20 @@ class DataTrainingArguments:
    train_split_name: str = field(
        default="train",
        metadata={
-            "help": "The name of the training data set split to use (via the datasets library). " "Defaults to 'train'"
+            "help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="validation",
        metadata={
-            "help": "The name of the evaluation data set split to use (via the datasets library). "
+            "help": "The name of the evaluation dataset split to use (via the datasets library). "
            "Defaults to 'validation'"
        },
    )
    predict_split_name: str = field(
        default="test",
        metadata={
-            "help": "The name of the prediction data set split to use (via the datasets library). "
-            "Defaults to 'test'"
+            "help": "The name of the prediction dataset split to use (via the datasets library). " "Defaults to 'test'"
        },
    )
    audio_column_name: str = field(
@@ -191,10 +189,10 @@
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    target_column_name: str = field(
-        default="transcription",
+        default=None,
        metadata={
            "help": "The name of the dataset column containing the target data "
-            "(transcription/translation/label). Defaults to 'transcription'"
+            "(transcription/translation/label). If None, the name will be inferred from the task. Defaults to None."
        },
    )
    overwrite_cache: bool = field(
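
These dataclass fields are what generate the script's CLI flags: `run_xtreme_s.py` follows the usual `HfArgumentParser` pattern, roughly as in this sketch (field subset chosen for illustration):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    dataset_name: str = field(
        default="google/xtreme_s",
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    target_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "Target column; if None, inferred from the task."},
    )


# Every field becomes a flag, e.g. --dataset_name and --target_column_name,
# with the metadata "help" string as its --help text.
(data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses()
print(data_args.dataset_name)
```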
@@ -348,8 +346,10 @@ def create_vocabulary_from_data(
    )

    # take union of all unique characters in each dataset
-    vocab_set = functools.reduce(
-        lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
+    vocab_set = (
+        (set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set())
+        | (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set())
+        | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
    )

    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
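
This hunk is the "Fix vocab union" from the commit message: `functools.reduce` without an initializer misbehaves when `vocabs` holds a single split (it returns the lone dict untouched) and would crash the lambda for three splits (the accumulated set has no `["vocab"]`). A standalone sketch with hypothetical data:

```python
import functools

# Hypothetical per-split vocab structure, with only one split present.
vocabs = {"train": {"vocab": [["a", "b", "c"]]}}

# Old code: with a single element, reduce never calls the lambda and
# returns the raw dict instead of a set of characters.
old_result = functools.reduce(
    lambda v1, v2: set(v1["vocab"][0]) | set(v2["vocab"][0]), vocabs.values()
)
print(type(old_result))  # <class 'dict'>, not a set

# New code: an explicit union over the known split names works for any
# subset of {train, eval, predict}.
vocab_set = (
    (set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set())
    | (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set())
    | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
)
print(sorted(vocab_set))  # ['a', 'b', 'c']
```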
@@ -434,7 +434,10 @@ def main():
" for multi-lingual fine-tuning."
|
||||
)
|
||||
|
||||
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
|
||||
if data_args.target_column_name is None:
|
||||
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
|
||||
else:
|
||||
target_column_name = data_args.target_column_name
|
||||
|
||||
# here we differentiate between tasks with text as the target and classification tasks
|
||||
is_text_target = target_column_name in ("transcription", "translation")
|
||||
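
`TASK_TO_TARGET_COLUMN_NAME` is defined earlier in the script and is not shown in this diff; judging from the task list and the `is_text_target` check above, it plausibly maps tasks to columns along these lines (an assumption, not the verbatim constant):

```python
# Plausible shape of the task-to-target mapping (classification column names
# assumed from the underlying datasets; not the verbatim constant).
TASK_TO_TARGET_COLUMN_NAME = {
    "fleurs-asr": "transcription",
    "mls": "transcription",
    "voxpopuli": "transcription",
    "covost2": "translation",
    "fleurs-lang_id": "lang_id",
    "minds14": "intent_class",
}
```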
@@ -457,9 +460,9 @@ def main():
f"{', '.join(raw_datasets['train'].column_names)}."
|
||||
)
|
||||
|
||||
if data_args.target_column_name not in raw_datasets["train"].column_names:
|
||||
if target_column_name not in raw_datasets["train"].column_names:
|
||||
raise ValueError(
|
||||
f"--target_column_name {data_args.target_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||
f"--target_column_name {target_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||
"Make sure to set `--target_column_name` to the correct text column - one of "
|
||||
f"{', '.join(raw_datasets['train'].column_names)}."
|
||||
)
|
||||
@@ -468,7 +471,7 @@ def main():
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
||||
|
||||
if not is_text_target:
|
||||
label_list = raw_datasets["train"].features[data_args.target_column_name].names
|
||||
label_list = raw_datasets["train"].features[target_column_name].names
|
||||
num_labels = len(label_list)
|
||||
|
||||
if training_args.do_eval:
|
||||
@@ -684,7 +687,7 @@ def main():
        if is_text_target:
            batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
        else:
-            batch["labels"] = batch[data_args.target_column_name]
+            batch["labels"] = batch[target_column_name]
        return batch

    with training_args.main_process_first(desc="dataset map preprocessing"):
@@ -809,10 +812,10 @@ def main():
        trainer.save_metrics("train", metrics)
        trainer.save_state()

-    # Evaluation
+    # Evaluation on the test set
    results = {}
    if training_args.do_predict:
-        logger.info("*** Predicte ***")
+        logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***")
        metrics = trainer.evaluate(vectorized_datasets["predict"])
        max_predict_samples = (
            data_args.max_predict_samples
@@ -831,9 +834,8 @@ def main():
"tags": [task_name, data_args.dataset_name],
|
||||
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
|
||||
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
||||
"language": data_args.language,
|
||||
}
|
||||
if "common_voice" in data_args.dataset_name:
|
||||
kwargs["language"] = config_name
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|