[Examples] Added support for test-file in QA examples with no trainer (#11510)

* added support for test-file

* fixed typo

* added suggested changes

* reformatted code

* modified files

* fixed post-processing error

* Trigger CI

* removed extra lines
Bhadresh Savani 2021-04-30 18:32:50 +05:30 committed by GitHub
parent af0692a2ca
commit 84326a28f8
4 changed files with 52 additions and 26 deletions

datasets (new submodule)

@@ -0,0 +1 @@
+Subproject commit 8afd0ba8c27800a55ea69d9fcd702dc97d9c16d8

examples/pytorch/question-answering/README.md

@@ -172,8 +172,6 @@ accelerate test
 that will check everything is ready for training. Finally, you can launch training with

 ```bash
-export TASK_NAME=mrpc
-
 accelerate launch run_qa_no_trainer.py \
 --model_name_or_path bert-base-uncased \
 --dataset_name squad \
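
For reference, a sketch of how the new option fits a full invocation with custom files (file names and the output path are illustrative, not taken from this diff):

```bash
# Illustrative use of the new --test_file / --do_predict options;
# file names and the output directory are placeholders.
accelerate launch run_qa_no_trainer.py \
  --model_name_or_path bert-base-uncased \
  --train_file train.json \
  --validation_file dev.json \
  --test_file test.json \
  --do_predict \
  --output_dir /tmp/qa_output
```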

examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py

@@ -80,6 +80,9 @@ def parse_args():
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
+    parser.add_argument(
+        "--test_file", type=str, default=None, help="A csv or a json file containing the prediction data."
+    )
     parser.add_argument(
         "--max_seq_length",
         type=int,
@@ -202,8 +205,13 @@ def parse_args():
     args = parser.parse_args()

     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
+    if (
+        args.dataset_name is None
+        and args.train_file is None
+        and args.validation_file is None
+        and args.test_file is None
+    ):
+        raise ValueError("Need either a dataset name or a training/validation/test file.")
     else:
         if args.train_file is not None:
             extension = args.train_file.split(".")[-1]
@@ -211,6 +219,9 @@ def parse_args():
     if args.validation_file is not None:
         extension = args.validation_file.split(".")[-1]
         assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+    if args.test_file is not None:
+        extension = args.test_file.split(".")[-1]
+        assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."

     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
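
The added check mirrors the existing ones for `train_file` and `validation_file`. A minimal, self-contained sketch of the pattern (the argument value is made up):

```python
# Standalone sketch of the extension sanity check added above;
# the file name is illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--test_file", type=str, default=None)
args = parser.parse_args(["--test_file", "my_test.json"])

if args.test_file is not None:
    extension = args.test_file.split(".")[-1]
    assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
```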
@@ -263,8 +274,10 @@ def main():
             data_files["train"] = args.train_file
         if args.validation_file is not None:
             data_files["validation"] = args.validation_file
+        if args.test_file is not None:
+            data_files["test"] = args.test_file
         extension = args.train_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, data_files=data_files, field="data")
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
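
The new `field="data"` argument tells the `json` builder to read examples from a top-level `"data"` key, which is how SQuAD-formatted files are laid out. A minimal sketch of the call under that assumption (file names are illustrative); note that `field` is an option of the `json` loader, so this path presumes JSON inputs:

```python
# Minimal sketch: load SQuAD-style JSON whose examples sit under the
# top-level "data" key. File names here are illustrative.
from datasets import load_dataset

data_files = {"validation": "dev-v1.json", "test": "my_test.json"}
raw_datasets = load_dataset("json", data_files=data_files, field="data")
print(raw_datasets["test"].column_names)
```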
@@ -535,13 +548,15 @@ def main():
         train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
     )

-    eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
+    eval_dataloader = DataLoader(
+        eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+    )

     if args.do_predict:
-        predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
+        predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"])
         predict_dataloader = DataLoader(
-            predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+            predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
         )

     # Post-processing:
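
This swaps `set_format` for `remove_columns`: the original `eval_dataset` keeps `example_id` and `offset_mapping` for post-processing, while the dataloader sees a copy without them, since the model's forward pass cannot accept those columns. A minimal sketch of the pattern with toy data (`default_data_collator` stands in for the script's `data_collator`):

```python
# Sketch: keep the full dataset for post-processing, feed the model a copy
# without the bookkeeping columns. The toy data below is made up.
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator

eval_dataset = Dataset.from_dict(
    {
        "input_ids": [[101, 2054, 102], [101, 2129, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1]],
        "example_id": ["q1", "q2"],                        # kept for post-processing
        "offset_mapping": [[[0, 1], [1, 4], [4, 5]]] * 2,  # kept for post-processing
    }
)

eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
eval_dataset_for_model.set_format("torch")
eval_dataloader = DataLoader(
    eval_dataset_for_model, collate_fn=default_data_collator, batch_size=2
)

for batch in eval_dataloader:
    print(batch.keys())  # only model inputs; eval_dataset still holds everything
```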
@@ -709,21 +724,21 @@ def main():
     start_top_index_concat = create_and_fill_np_array(all_start_top_index, eval_dataset, max_len)
     end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, eval_dataset, max_len)
     end_top_index_concat = create_and_fill_np_array(all_end_top_index, eval_dataset, max_len)
-    all_cls_logits = np.concatenate(all_cls_logits, axis=0)
+    cls_logits_concat = np.concatenate(all_cls_logits, axis=0)

     # delete the list of numpy arrays
     del start_top_log_probs
     del start_top_index
     del end_top_log_probs
     del end_top_index
     del cls_logits

     eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
     outputs_numpy = (
         start_top_log_probs_concat,
         start_top_index_concat,
         end_top_log_probs_concat,
         end_top_index_concat,
-        cls_logits,
+        cls_logits_concat,
     )
     prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
     eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
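
The renames above appear to be the post-processing fix from the commit message: the per-batch variable `cls_logits` is deleted just before the output tuple is built, so referencing it there would raise a `NameError`; the concatenated array now lives under its own name, `cls_logits_concat`. A toy illustration of the concatenation step (shapes are made up):

```python
# Toy illustration: per-batch CLS logits are collected in a list, then
# concatenated into one array for post-processing. Shapes are made up.
import numpy as np

all_cls_logits = [np.zeros(8), np.zeros(8), np.zeros(3)]  # last batch is smaller
cls_logits_concat = np.concatenate(all_cls_logits, axis=0)
assert cls_logits_concat.shape == (19,)
```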
@@ -766,21 +781,21 @@ def main():
         start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
         end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
         end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
-        all_cls_logits = np.concatenate(all_cls_logits, axis=0)
+        cls_logits_concat = np.concatenate(all_cls_logits, axis=0)

         # delete the list of numpy arrays
         del start_top_log_probs
         del start_top_index
         del end_top_log_probs
         del end_top_index
         del cls_logits

         predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
         outputs_numpy = (
             start_top_log_probs_concat,
             start_top_index_concat,
             end_top_log_probs_concat,
             end_top_index_concat,
-            cls_logits,
+            cls_logits_concat,
         )
         prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)

examples/pytorch/question-answering/run_qa_no_trainer.py

@@ -81,10 +81,13 @@ def parse_args():
     parser.add_argument(
         "--preprocessing_num_workers", type=int, default=4, help="The number of processes to use for preprocessing."
     )
-    parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
+    parser.add_argument("--do_predict", action="store_true", help="Run prediction with the question answering model.")
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
+    parser.add_argument(
+        "--test_file", type=str, default=None, help="A csv or a json file containing the prediction data."
+    )
     parser.add_argument(
         "--max_seq_length",
         type=int,
@@ -231,8 +234,13 @@ def parse_args():
     args = parser.parse_args()

     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
+    if (
+        args.dataset_name is None
+        and args.train_file is None
+        and args.validation_file is None
+        and args.test_file is None
+    ):
+        raise ValueError("Need either a dataset name or a training/validation/test file.")
     else:
         if args.train_file is not None:
             extension = args.train_file.split(".")[-1]
@@ -240,6 +248,9 @@ def parse_args():
     if args.validation_file is not None:
         extension = args.validation_file.split(".")[-1]
         assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+    if args.test_file is not None:
+        extension = args.test_file.split(".")[-1]
+        assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."

     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
@@ -292,8 +303,10 @@ def main():
             data_files["train"] = args.train_file
         if args.validation_file is not None:
             data_files["validation"] = args.validation_file
+        if args.test_file is not None:
+            data_files["test"] = args.test_file
         extension = args.train_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, data_files=data_files, field="data")
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -540,13 +553,15 @@ def main():
         train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
     )

-    eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
+    eval_dataloader = DataLoader(
+        eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+    )

     if args.do_predict:
-        predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
+        predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"])
         predict_dataloader = DataLoader(
-            predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+            predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
         )

     # Post-processing:
@@ -704,7 +719,6 @@ def main():
     del all_start_logits
     del all_end_logits

-    eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
     outputs_numpy = (start_logits_concat, end_logits_concat)
     prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
     eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
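
Dropping this line follows from the dataloader change above: the dataset is never switched to torch format anymore, so there is nothing to reset before post-processing. A toy sketch of what the removed reset used to do (data is made up):

```python
# Toy sketch: set_format(type="torch", columns=[...]) hides the unlisted
# columns from __getitem__; set_format(type=None, ...) restores them.
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2]], "example_id": ["q1"]})
ds.set_format(type="torch", columns=["input_ids"])
print(ds[0].keys())  # only the formatted column
ds.set_format(type=None, columns=list(ds.features.keys()))
print(ds[0].keys())  # all columns again
```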
@@ -736,8 +750,6 @@ def main():
         del all_start_logits
         del all_end_logits

-        # Now we need to add extra columns which we removed for post processing
-        predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
         outputs_numpy = (start_logits_concat, end_logits_concat)
         prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
         predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)