mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Update run_mlm.py (#12344)
Before the code could not be used for validation only because of this line: extension = data_args.train_file.split(".")[-1] was assuming that extension must be extracted from the training dataset. This line would run regardless of the training or validation options of the user. This would lead to an error if the user only wants to run an evaluation only and does not want to do train (because the training file does not exist). I modified it to extract extension from the training file if the user wants to do train and extract it from the validation file if the user wants to run eval. This way the code can be used for both training and validation separately.
This commit is contained in:
parent
c7faf2ccc0
commit
9490d668d2
@ -278,9 +278,10 @@ def main():
|
|||||||
data_files = {}
|
data_files = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = data_args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
|
extension = data_args.train_file.split(".")[-1]
|
||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
data_files["validation"] = data_args.validation_file
|
data_files["validation"] = data_args.validation_file
|
||||||
extension = data_args.train_file.split(".")[-1]
|
extension = data_args.validation_file.split(".")[-1]
|
||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
||||||
|
Loading…
Reference in New Issue
Block a user