diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 94821a24f04..0c655f08b48 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -227,7 +227,11 @@ def main(): data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] + extension = ( + data_args.train_file.split(".")[-1] + if data_args.train_file is not None + else data_args.validation_file.split(".")[-1] + ) if extension == "txt": extension = "text" datasets = load_dataset(extension, data_files=data_files)