mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Xtreme-S] fix some namings (#16183)
This commit is contained in:
parent
99fd3eb4a5
commit
c2dc89be62
@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node=8 \
|
||||
run_xtreme_s.py \
|
||||
--task="mls" \
|
||||
--language="all" \
|
||||
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
||||
--dataset_name="google/xtreme_s" \
|
||||
--dataset_config_name="mls.all" \
|
||||
--eval_split_name="test" \
|
||||
--output_dir="xtreme_s_xlsr_300m_mls" \
|
||||
--overwrite_output_dir \
|
||||
@ -94,7 +94,6 @@ python -m torch.distributed.launch \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps=3000 \
|
||||
--evaluation_strategy="steps" \
|
||||
--target_column_name="transcription" \
|
||||
--max_duration_in_seconds=20 \
|
||||
--save_steps=500 \
|
||||
--eval_steps=500 \
|
||||
@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node=2 \
|
||||
run_xtreme_s.py \
|
||||
--task="minds14" \
|
||||
--language="all" \
|
||||
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
||||
--dataset_name="google/xtreme_s" \
|
||||
--dataset_config_name="minds14.all" \
|
||||
--eval_split_name="test" \
|
||||
--output_dir="xtreme_s_xlsr_300m_minds14" \
|
||||
--overwrite_output_dir \
|
||||
--num_train_epochs=50 \
|
||||
@ -139,7 +137,6 @@ python -m torch.distributed.launch \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps=1500 \
|
||||
--evaluation_strategy="steps" \
|
||||
--target_column_name="intent_class" \
|
||||
--max_duration_in_seconds=30 \
|
||||
--save_steps=200 \
|
||||
--eval_steps=200 \
|
@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
|
||||
return field(default_factory=lambda: default, metadata=metadata)
|
||||
|
||||
|
||||
TASK_TO_TARGET_COLUMN_NAME = {
|
||||
"fleurs-asr": "transcription",
|
||||
"fleurs-lang_id": "lang_id",
|
||||
"mls": "transcription",
|
||||
"voxpopuli": "transcription",
|
||||
"covost2": "translation",
|
||||
"minds14": "intent_class",
|
||||
"babel": "transcription",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
@ -144,8 +155,16 @@ class DataTrainingArguments:
|
||||
default="xtreme_s",
|
||||
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
|
||||
)
|
||||
dataset_config_name: str = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
task: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
|
||||
"'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
|
||||
},
|
||||
)
|
||||
language: str = field(
|
||||
default="all",
|
||||
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
|
||||
)
|
||||
train_split_name: str = field(
|
||||
default="train",
|
||||
@ -160,6 +179,13 @@ class DataTrainingArguments:
|
||||
"Defaults to 'validation'"
|
||||
},
|
||||
)
|
||||
predict_split_name: str = field(
|
||||
default="test",
|
||||
metadata={
|
||||
"help": "The name of the prediction data set split to use (via the datasets library). "
|
||||
"Defaults to 'test'"
|
||||
},
|
||||
)
|
||||
audio_column_name: str = field(
|
||||
default="audio",
|
||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||
@ -192,6 +218,13 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
chars_to_ignore: Optional[List[str]] = list_field(
|
||||
default=', ? . ! - ; : " “ % ‘ ” <20>'.split(" "),
|
||||
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||
@ -387,22 +420,31 @@ def main():
|
||||
|
||||
# 1. First, let's load the dataset
|
||||
raw_datasets = DatasetDict()
|
||||
if data_args.dataset_config_name is None:
|
||||
task_name = data_args.task
|
||||
lang_id = data_args.language
|
||||
|
||||
if task_name is None:
|
||||
raise ValueError(
|
||||
"Set --dataset_config_name should be set to '<xtreme_s_subset>.<language(s)>' "
|
||||
"(e.g. 'mls.pl', 'covost2.en.tr', 'minds14.fr-FR') "
|
||||
"or '<xtreme_s_subset>.all' for multi-lingual fine-tuning."
|
||||
"Set --task should be set to '<xtreme_s_task>' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
|
||||
)
|
||||
if lang_id is None:
|
||||
raise ValueError(
|
||||
"Set --language should be set to the language id of the sub dataset "
|
||||
"config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
|
||||
" for multi-lingual fine-tuning."
|
||||
)
|
||||
|
||||
task_name = data_args.dataset_config_name.split(".")[0]
|
||||
target_column_name = data_args.target_column_name
|
||||
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
|
||||
|
||||
# here we differentiate between tasks with text as the target and classification tasks
|
||||
is_text_target = target_column_name in ("transcription", "translation")
|
||||
|
||||
config_name = ".".join([task_name.split("-")[0], lang_id])
|
||||
|
||||
if training_args.do_train:
|
||||
raw_datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
config_name,
|
||||
split=data_args.train_split_name,
|
||||
use_auth_token=data_args.use_auth_token,
|
||||
cache_dir=model_args.cache_dir,
|
||||
@ -432,7 +474,7 @@ def main():
|
||||
if training_args.do_eval:
|
||||
raw_datasets["eval"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
config_name,
|
||||
split=data_args.eval_split_name,
|
||||
use_auth_token=data_args.use_auth_token,
|
||||
cache_dir=model_args.cache_dir,
|
||||
@ -441,6 +483,18 @@ def main():
|
||||
if data_args.max_eval_samples is not None:
|
||||
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
||||
|
||||
if training_args.do_predict:
|
||||
raw_datasets["predict"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
config_name,
|
||||
split=data_args.predict_split_name,
|
||||
use_auth_token=data_args.use_auth_token,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
if data_args.max_predict_samples is not None:
|
||||
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
|
||||
|
||||
# 2. We remove some special characters from the datasets
|
||||
# that make training complicated and do not help in transcribing the speech
|
||||
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
||||
@ -757,24 +811,25 @@ def main():
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
max_eval_samples = (
|
||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predicte ***")
|
||||
metrics = trainer.evaluate(vectorized_datasets["predict"])
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples
|
||||
if data_args.max_predict_samples is not None
|
||||
else len(vectorized_datasets["predict"])
|
||||
)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
# Write model card and (optionally) push to hub
|
||||
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
|
||||
kwargs = {
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"tasks": "speech-recognition",
|
||||
"tags": ["automatic-speech-recognition", data_args.dataset_name],
|
||||
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
|
||||
"tasks": task_name,
|
||||
"tags": [task_name, data_args.dataset_name],
|
||||
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
|
||||
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
||||
}
|
||||
if "common_voice" in data_args.dataset_name:
|
Loading…
Reference in New Issue
Block a user