Update run_glue for do_predict with local test data (#9442) (#9486)

* Update run_glue for do_predict with local test data (#9442)

* Update run_glue (#9442): fix comments ('files' to 'a file')

* Update run_glue (#9442): reflect the code review

* Update run_glue (#9442): auto format

* Update run_glue (#9442): reflect the code review
This commit is contained in:
Yusuke Mori 2021-01-13 21:48:35 +09:00 committed by GitHub
parent 0c9f01a8e5
commit eabad8fd9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -93,6 +93,7 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None, metadata={"help": "A csv or a json file containing the validation data."}
)
test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
def __post_init__(self):
if self.task_name is not None:
@ -102,10 +103,12 @@ class DataTrainingArguments:
elif self.train_file is None or self.validation_file is None:
raise ValueError("Need either a GLUE task or a training/validation file.")
else:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
train_extension = self.train_file.split(".")[-1]
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
validation_extension = self.validation_file.split(".")[-1]
assert (
validation_extension == train_extension
), "`validation_file` should have the same extension (csv or json) as `train_file`."
@dataclass
@ -205,16 +208,33 @@ def main():
if data_args.task_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset("glue", data_args.task_name)
elif data_args.train_file.endswith(".csv"):
# Loading a dataset from local csv files
datasets = load_dataset(
"csv", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
)
else:
# Loading a dataset from local json files
datasets = load_dataset(
"json", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
)
# Loading a dataset from your local files.
# CSV/JSON training and evaluation files are needed.
data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
# Get the test dataset: you can provide your own CSV/JSON test file (see below)
# when you use `do_predict` without specifying a GLUE benchmark task.
if training_args.do_predict:
if data_args.test_file is not None:
train_extension = data_args.train_file.split(".")[-1]
test_extension = data_args.test_file.split(".")[-1]
assert (
test_extension == train_extension
), "`test_file` should have the same extension (csv or json) as `train_file`."
data_files["test"] = data_args.test_file
else:
raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
for key in data_files.keys():
logger.info(f"load a local file for {key}: {data_files[key]}")
if data_args.train_file.endswith(".csv"):
# Loading a dataset from local csv files
datasets = load_dataset("csv", data_files=data_files)
else:
# Loading a dataset from local json files
datasets = load_dataset("json", data_files=data_files)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
@ -323,7 +343,7 @@ def main():
train_dataset = datasets["train"]
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.task_name is not None:
if data_args.task_name is not None or data_args.test_file is not None:
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
# Log a few random samples from the training set: