[run_clm.py] fix getting extention

This commit is contained in:
Suraj Patil 2021-02-03 20:14:42 +05:30 committed by GitHub
parent 5442a11f5f
commit bca0dd5ee3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -227,7 +227,11 @@ def main():
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = (
data_args.train_file.split(".")[-1]
if data_args.train_file is not None
else data_args.validation_file.split(".")[-1]
)
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files)