Update run_glue for do_predict with local test data (#9442)

* Update run_glue (#9442): fix comments ('files' to 'a file')
* Update run_glue (#9442): reflect the code review
* Update run_glue (#9442): auto format
* Update run_glue (#9442): reflect the code review
This commit is contained in:
parent 0c9f01a8e5
commit eabad8fd9c
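The change in one picture: a hedged sketch of the workflow this commit enables, driving run_glue.py over local CSV files and requesting predictions on a local test set. The flag names mirror the DataTrainingArguments fields in the diff below; the model name, file paths, and output directory are placeholders, not values from this commit.

import subprocess

# Train/evaluate on local CSV files, then predict on a local test file.
# --test_file and the do_predict handling for local data are what this
# commit adds; all paths here are hypothetical.
subprocess.run(
    [
        "python", "run_glue.py",
        "--model_name_or_path", "bert-base-cased",
        "--train_file", "train.csv",
        "--validation_file", "dev.csv",
        "--test_file", "test.csv",
        "--do_train", "--do_eval", "--do_predict",
        "--output_dir", "./glue_out",
    ],
    check=True,
)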
run_glue.py
@@ -93,6 +93,7 @@ class DataTrainingArguments:
     validation_file: Optional[str] = field(
         default=None, metadata={"help": "A csv or a json file containing the validation data."}
     )
+    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
 
     def __post_init__(self):
         if self.task_name is not None:
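The new test_file field becomes a --test_file CLI flag automatically, because the example scripts parse DataTrainingArguments with HfArgumentParser. A minimal sketch using a trimmed, hypothetical copy of the dataclass:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Trimmed to the field this hunk adds; the real class has more fields.
    test_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the test data."}
    )

# HfArgumentParser turns each dataclass field into an argparse option.
(data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses(
    args=["--test_file", "test.csv"]
)
print(data_args.test_file)  # -> test.csv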
@@ -102,10 +103,12 @@ class DataTrainingArguments:
         elif self.train_file is None or self.validation_file is None:
             raise ValueError("Need either a GLUE task or a training/validation file.")
         else:
-            extension = self.train_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-            extension = self.validation_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            train_extension = self.train_file.split(".")[-1]
+            assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+            validation_extension = self.validation_file.split(".")[-1]
+            assert (
+                validation_extension == train_extension
+            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
 
 
 @dataclass
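The hunk above tightens the validation in __post_init__: previously both files only had to be csv or json independently; now the validation file must also share the train file's extension, and the separate csv/json check on the validation file becomes redundant, since matching an already-validated train extension implies it. A standalone sketch of the new rule (function name hypothetical):

def check_extensions(train_file, validation_file):
    # Mirrors the __post_init__ logic after this commit.
    train_extension = train_file.split(".")[-1]
    assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
    validation_extension = validation_file.split(".")[-1]
    assert (
        validation_extension == train_extension
    ), "`validation_file` should have the same extension (csv or json) as `train_file`."

check_extensions("train.csv", "dev.csv")    # ok
# check_extensions("train.csv", "dev.json") # AssertionError: extensions must match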
@@ -205,16 +208,33 @@ def main():
     if data_args.task_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset("glue", data_args.task_name)
-    elif data_args.train_file.endswith(".csv"):
-        # Loading a dataset from local csv files
-        datasets = load_dataset(
-            "csv", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
-        )
-    else:
-        # Loading a dataset from local json files
-        datasets = load_dataset(
-            "json", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
-        )
+    else:
+        # Loading a dataset from your local files.
+        # CSV/JSON training and evaluation files are needed.
+        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
+
+        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
+        # when you use `do_predict` without specifying a GLUE benchmark task.
+        if training_args.do_predict:
+            if data_args.test_file is not None:
+                train_extension = data_args.train_file.split(".")[-1]
+                test_extension = data_args.test_file.split(".")[-1]
+                assert (
+                    test_extension == train_extension
+                ), "`test_file` should have the same extension (csv or json) as `train_file`."
+                data_files["test"] = data_args.test_file
+            else:
+                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
+
+        for key in data_files.keys():
+            logger.info(f"load a local file for {key}: {data_files[key]}")
+
+        if data_args.train_file.endswith(".csv"):
+            # Loading a dataset from local csv files
+            datasets = load_dataset("csv", data_files=data_files)
+        else:
+            # Loading a dataset from local json files
+            datasets = load_dataset("json", data_files=data_files)
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
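With a single data_files dict, one load_dataset call now produces all three splits. A minimal sketch, assuming local files train.csv, dev.csv, and test.csv exist with a header row:

from datasets import load_dataset

data_files = {"train": "train.csv", "validation": "dev.csv", "test": "test.csv"}
datasets = load_dataset("csv", data_files=data_files)
# -> DatasetDict with "train", "validation", and "test" splits, so the
#    rest of the script can index datasets["test"] when --do_predict is set.
print(datasets)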
@@ -323,7 +343,7 @@ def main():
 
     train_dataset = datasets["train"]
     eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
-    if data_args.task_name is not None:
+    if data_args.task_name is not None or data_args.test_file is not None:
         test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
 
     # Log a few random samples from the training set:
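The only subtlety in this last hunk is the split name: GLUE's MNLI ships matched/mismatched test sets, while local files land in a plain "test" split. A tiny sketch of the ternary's behavior:

def test_split_name(task_name):
    # Mirrors the expression in the hunk above.
    return "test_matched" if task_name == "mnli" else "test"

print(test_split_name("mnli"))  # -> test_matched
print(test_split_name(None))    # -> test (e.g. local files with --do_predict)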