From 9490d668d2f59ad2e7a4db3dc7ed2f9684af369c Mon Sep 17 00:00:00 2001 From: Taha ValizadehAslani <47432410+TahaAslani@users.noreply.github.com> Date: Mon, 28 Jun 2021 06:49:22 -0500 Subject: [PATCH] Update run_mlm.py (#12344) Before the code could not be used for validation only because of this line: extension = data_args.train_file.split(".")[-1] was assuming that extension must be extracted from the training dataset. This line would run regardless of the training or validation options of the user. This would lead to an error if the user only wants to run an evaluation only and does not want to do train (because the training file does not exist). I modified it to extract extension from the training file if the user wants to do train and extract it from the validation file if the user wants to run eval. This way the code can be used for both training and validation separately. --- examples/pytorch/language-modeling/run_mlm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 84bc59186b7..57bd1d891fd 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -278,9 +278,10 @@ def main(): data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] + extension = data_args.validation_file.split(".")[-1] if extension == "txt": extension = "text" raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)