[Xtreme-S] fix some namings (#16183)

Patrick von Platen 2022-03-16 01:21:31 +01:00 committed by GitHub
parent 99fd3eb4a5
commit c2dc89be62
3 changed files with 81 additions and 29 deletions

README.md

@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
python -m torch.distributed.launch \
--nproc_per_node=8 \
run_xtreme_s.py \
--task="mls" \
--language="all" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--dataset_name="google/xtreme_s" \
--dataset_config_name="mls.all" \
--eval_split_name="test" \
--output_dir="xtreme_s_xlsr_300m_mls" \
--overwrite_output_dir \
@@ -94,7 +94,6 @@ python -m torch.distributed.launch \
--learning_rate="3e-4" \
--warmup_steps=3000 \
--evaluation_strategy="steps" \
--target_column_name="transcription" \
--max_duration_in_seconds=20 \
--save_steps=500 \
--eval_steps=500 \
@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
python -m torch.distributed.launch \
--nproc_per_node=2 \
run_xtreme_s.py \
--task="minds14" \
--language="all" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--dataset_name="google/xtreme_s" \
--dataset_config_name="minds14.all" \
--eval_split_name="test" \
--output_dir="xtreme_s_xlsr_300m_minds14" \
--overwrite_output_dir \
--num_train_epochs=50 \
@@ -139,7 +137,6 @@ python -m torch.distributed.launch \
--learning_rate="3e-4" \
--warmup_steps=1500 \
--evaluation_strategy="steps" \
--target_column_name="intent_class" \
--max_duration_in_seconds=30 \
--save_steps=200 \
--eval_steps=200 \
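
Taken together, the updated commands select the benchmark through --task and --language instead of a combined --dataset_config_name. A minimal sketch, not part of the commit, of how the two flags recombine into the datasets config name; to_config_name is a hypothetical helper that mirrors the config_name logic added to run_xtreme_s.py further down:

def to_config_name(task: str, language: str) -> str:
    # "fleurs-asr" and "fleurs-lang_id" both load the "fleurs" config;
    # every other task name doubles as the config prefix.
    return ".".join([task.split("-")[0], language])

assert to_config_name("mls", "all") == "mls.all"          # old --dataset_config_name="mls.all"
assert to_config_name("minds14", "all") == "minds14.all"  # old --dataset_config_name="minds14.all"
assert to_config_name("fleurs-asr", "all") == "fleurs.all"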

run_xtreme_s.py

@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
TASK_TO_TARGET_COLUMN_NAME = {
"fleurs-asr": "transcription",
"fleurs-lang_id": "lang_id",
"mls": "transcription",
"voxpopuli": "transcription",
"covost2": "translation",
"minds14": "intent_class",
"babel": "transcription",
}
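
For orientation, a small sketch (not part of the diff) of how this new mapping is consumed later in the script: the task name selects the label column, and text-valued targets switch the script into ASR/translation mode rather than classification.

target_column_name = TASK_TO_TARGET_COLUMN_NAME["minds14"]               # -> "intent_class"
is_text_target = target_column_name in ("transcription", "translation")  # -> False, i.e. a classification task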
@dataclass
class ModelArguments:
"""
@@ -144,8 +155,16 @@ class DataTrainingArguments:
default="xtreme_s",
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
)
dataset_config_name: str = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
task: str = field(
default=None,
metadata={
"help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
"'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
},
)
language: str = field(
default="all",
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
)
train_split_name: str = field(
default="train",
@@ -160,6 +179,13 @@ class DataTrainingArguments:
"Defaults to 'validation'"
},
)
predict_split_name: str = field(
default="test",
metadata={
"help": "The name of the prediction data set split to use (via the datasets library). "
"Defaults to 'test'"
},
)
audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
@@ -192,6 +218,13 @@ class DataTrainingArguments:
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
chars_to_ignore: Optional[List[str]] = list_field(
default=', ? . ! - ; : "% <20>'.split(" "),
metadata={"help": "A list of characters to remove from the transcripts."},
@@ -387,22 +420,31 @@ def main():
# 1. First, let's load the dataset
raw_datasets = DatasetDict()
if data_args.dataset_config_name is None:
task_name = data_args.task
lang_id = data_args.language
if task_name is None:
raise ValueError(
"Set --dataset_config_name should be set to '<xtreme_s_subset>.<language(s)>' "
"(e.g. 'mls.pl', 'covost2.en.tr', 'minds14.fr-FR') "
"or '<xtreme_s_subset>.all' for multi-lingual fine-tuning."
"Set --task should be set to '<xtreme_s_task>' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
)
if lang_id is None:
raise ValueError(
"Set --language should be set to the language id of the sub dataset "
"config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
" for multi-lingual fine-tuning."
)
task_name = data_args.dataset_config_name.split(".")[0]
target_column_name = data_args.target_column_name
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
# here we differentiate between tasks with text as the target and classification tasks
is_text_target = target_column_name in ("transcription", "translation")
config_name = ".".join([task_name.split("-")[0], lang_id])
if training_args.do_train:
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
config_name,
split=data_args.train_split_name,
use_auth_token=data_args.use_auth_token,
cache_dir=model_args.cache_dir,
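
In effect, for the README's MLS command the branch above boils down to a single datasets call; an illustrative sketch using values taken from this diff ("mls" + "all" resolve to the "mls.all" config, the dataset name comes from the README command, and train_split_name defaults to "train"):

from datasets import load_dataset

train_dataset = load_dataset("google/xtreme_s", "mls.all", split="train")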
@@ -432,7 +474,7 @@ def main():
if training_args.do_eval:
raw_datasets["eval"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
config_name,
split=data_args.eval_split_name,
use_auth_token=data_args.use_auth_token,
cache_dir=model_args.cache_dir,
@@ -441,6 +483,18 @@ def main():
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
if training_args.do_predict:
raw_datasets["predict"] = load_dataset(
data_args.dataset_name,
config_name,
split=data_args.predict_split_name,
use_auth_token=data_args.use_auth_token,
cache_dir=model_args.cache_dir,
)
if data_args.max_predict_samples is not None:
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
# 2. We remove some special characters from the datasets
# that make training complicated and do not help in transcribing the speech
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
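
The removal itself happens outside the lines shown in this hunk; a hypothetical sketch of that cleanup, using a subset of the chars_to_ignore default defined above (the target_text column name and helper are illustrative, not taken from the script):

import re

chars_to_ignore = [",", "?", ".", "!", "-", ";", ":", '"', "%"]  # subset of the chars_to_ignore default above
chars_to_ignore_regex = f"[{re.escape(''.join(chars_to_ignore))}]"

def remove_special_characters(batch, target_column_name="transcription"):
    # strip punctuation from the target column and lowercase it before tokenization
    batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[target_column_name]).lower() + " "
    return batch

print(remove_special_characters({"transcription": "Hello, world!"}))
# {'transcription': 'Hello, world!', 'target_text': 'hello world '}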
@@ -757,24 +811,25 @@ def main():
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
if training_args.do_predict:
logger.info("*** Predicte ***")
metrics = trainer.evaluate(vectorized_datasets["predict"])
max_predict_samples = (
data_args.max_predict_samples
if data_args.max_predict_samples is not None
else len(vectorized_datasets["predict"])
)
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
# Write model card and (optionally) push to hub
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
"tasks": task_name,
"tags": [task_name, data_args.dataset_name],
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name: