mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
* Update run_glue for do_predict with local test data (#9442) * Update run_glue (#9442): fix comments ('files' to 'a file') * Update run_glue (#9442): reflect the code review * Update run_glue (#9442): auto format * Update run_glue (#9442): reflect the code review
This commit is contained in:
parent
0c9f01a8e5
commit
eabad8fd9c
@ -93,6 +93,7 @@ class DataTrainingArguments:
|
||||
validation_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "A csv or a json file containing the validation data."}
|
||||
)
|
||||
test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.task_name is not None:
|
||||
@ -102,10 +103,12 @@ class DataTrainingArguments:
|
||||
elif self.train_file is None or self.validation_file is None:
|
||||
raise ValueError("Need either a GLUE task or a training/validation file.")
|
||||
else:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
train_extension = self.train_file.split(".")[-1]
|
||||
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
validation_extension = self.validation_file.split(".")[-1]
|
||||
assert (
|
||||
validation_extension == train_extension
|
||||
), "`validation_file` should have the same extension (csv or json) as `train_file`."
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -205,16 +208,33 @@ def main():
|
||||
if data_args.task_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset("glue", data_args.task_name)
|
||||
elif data_args.train_file.endswith(".csv"):
|
||||
# Loading a dataset from local csv files
|
||||
datasets = load_dataset(
|
||||
"csv", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
|
||||
)
|
||||
else:
|
||||
# Loading a dataset from local json files
|
||||
datasets = load_dataset(
|
||||
"json", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
|
||||
)
|
||||
# Loading a dataset from your local files.
|
||||
# CSV/JSON training and evaluation files are needed.
|
||||
data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
|
||||
|
||||
# Get the test dataset: you can provide your own CSV/JSON test file (see below)
|
||||
# when you use `do_predict` without specifying a GLUE benchmark task.
|
||||
if training_args.do_predict:
|
||||
if data_args.test_file is not None:
|
||||
train_extension = data_args.train_file.split(".")[-1]
|
||||
test_extension = data_args.test_file.split(".")[-1]
|
||||
assert (
|
||||
test_extension == train_extension
|
||||
), "`test_file` should have the same extension (csv or json) as `train_file`."
|
||||
data_files["test"] = data_args.test_file
|
||||
else:
|
||||
raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
|
||||
|
||||
for key in data_files.keys():
|
||||
logger.info(f"load a local file for {key}: {data_files[key]}")
|
||||
|
||||
if data_args.train_file.endswith(".csv"):
|
||||
# Loading a dataset from local csv files
|
||||
datasets = load_dataset("csv", data_files=data_files)
|
||||
else:
|
||||
# Loading a dataset from local json files
|
||||
datasets = load_dataset("json", data_files=data_files)
|
||||
# See more about loading any type of standard or custom dataset at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
@ -323,7 +343,7 @@ def main():
|
||||
|
||||
train_dataset = datasets["train"]
|
||||
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||
if data_args.task_name is not None:
|
||||
if data_args.task_name is not None or data_args.test_file is not None:
|
||||
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
|
Loading…
Reference in New Issue
Block a user