mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Examples] Added support for test-file in QA examples with no trainer (#11510)
* added support for test-file * fixed typo * added suggested changes * reformatted code * modifed files * fix post processing error * Trigger CI * removed extra lines
This commit is contained in:
parent
af0692a2ca
commit
84326a28f8
1
datasets
Submodule
1
datasets
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 8afd0ba8c27800a55ea69d9fcd702dc97d9c16d8
|
@ -172,8 +172,6 @@ accelerate test
|
||||
that will check everything is ready for training. Finally, you cna launch training with
|
||||
|
||||
```bash
|
||||
export TASK_NAME=mrpc
|
||||
|
||||
accelerate launch run_qa_no_trainer.py \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--dataset_name squad \
|
||||
|
@ -80,6 +80,9 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_file", type=str, default=None, help="A csv or a json file containing the Prediction data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
type=int,
|
||||
@ -202,8 +205,13 @@ def parse_args():
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
if (
|
||||
args.dataset_name is None
|
||||
and args.train_file is None
|
||||
and args.validation_file is None
|
||||
and args.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training/validation/test file.")
|
||||
else:
|
||||
if args.train_file is not None:
|
||||
extension = args.train_file.split(".")[-1]
|
||||
@ -211,6 +219,9 @@ def parse_args():
|
||||
if args.validation_file is not None:
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if args.test_file is not None:
|
||||
extension = args.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
@ -263,8 +274,10 @@ def main():
|
||||
data_files["train"] = args.train_file
|
||||
if args.validation_file is not None:
|
||||
data_files["validation"] = args.validation_file
|
||||
if args.test_file is not None:
|
||||
data_files["test"] = args.test_file
|
||||
extension = args.train_file.split(".")[-1]
|
||||
raw_datasets = load_dataset(extension, data_files=data_files)
|
||||
raw_datasets = load_dataset(extension, data_files=data_files, field="data")
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
@ -535,13 +548,15 @@ def main():
|
||||
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
|
||||
eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||
eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
if args.do_predict:
|
||||
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"])
|
||||
predict_dataloader = DataLoader(
|
||||
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
@ -709,21 +724,21 @@ def main():
|
||||
start_top_index_concat = create_and_fill_np_array(all_start_top_index, eval_dataset, max_len)
|
||||
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, eval_dataset, max_len)
|
||||
end_top_index_concat = create_and_fill_np_array(all_end_top_index, eval_dataset, max_len)
|
||||
all_cls_logits = np.concatenate(all_cls_logits, axis=0)
|
||||
cls_logits_concat = np.concatenate(all_cls_logits, axis=0)
|
||||
|
||||
# delete the list of numpy arrays
|
||||
del start_top_log_probs
|
||||
del start_top_index
|
||||
del end_top_log_probs
|
||||
del end_top_index
|
||||
del cls_logits
|
||||
|
||||
eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
|
||||
outputs_numpy = (
|
||||
start_top_log_probs_concat,
|
||||
start_top_index_concat,
|
||||
end_top_log_probs_concat,
|
||||
end_top_index_concat,
|
||||
cls_logits,
|
||||
cls_logits_concat,
|
||||
)
|
||||
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
|
||||
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
@ -766,21 +781,21 @@ def main():
|
||||
start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
|
||||
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
|
||||
end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
|
||||
all_cls_logits = np.concatenate(all_cls_logits, axis=0)
|
||||
cls_logits_concat = np.concatenate(all_cls_logits, axis=0)
|
||||
|
||||
# delete the list of numpy arrays
|
||||
del start_top_log_probs
|
||||
del start_top_index
|
||||
del end_top_log_probs
|
||||
del end_top_index
|
||||
del cls_logits
|
||||
|
||||
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
|
||||
outputs_numpy = (
|
||||
start_top_log_probs_concat,
|
||||
start_top_index_concat,
|
||||
end_top_log_probs_concat,
|
||||
end_top_index_concat,
|
||||
cls_logits,
|
||||
cls_logits_concat,
|
||||
)
|
||||
|
||||
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
|
||||
|
@ -81,10 +81,13 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
|
||||
)
|
||||
parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
|
||||
parser.add_argument("--do_predict", action="store_true", help="To do prediction on the question answering model")
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_file", type=str, default=None, help="A csv or a json file containing the Prediction data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
type=int,
|
||||
@ -231,8 +234,13 @@ def parse_args():
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
if (
|
||||
args.dataset_name is None
|
||||
and args.train_file is None
|
||||
and args.validation_file is None
|
||||
and args.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training/validation/test file.")
|
||||
else:
|
||||
if args.train_file is not None:
|
||||
extension = args.train_file.split(".")[-1]
|
||||
@ -240,6 +248,9 @@ def parse_args():
|
||||
if args.validation_file is not None:
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if args.test_file is not None:
|
||||
extension = args.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
@ -292,8 +303,10 @@ def main():
|
||||
data_files["train"] = args.train_file
|
||||
if args.validation_file is not None:
|
||||
data_files["validation"] = args.validation_file
|
||||
if args.test_file is not None:
|
||||
data_files["test"] = args.test_file
|
||||
extension = args.train_file.split(".")[-1]
|
||||
raw_datasets = load_dataset(extension, data_files=data_files)
|
||||
raw_datasets = load_dataset(extension, data_files=data_files, field="data")
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
@ -540,13 +553,15 @@ def main():
|
||||
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
|
||||
eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||
eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
if args.do_predict:
|
||||
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"])
|
||||
predict_dataloader = DataLoader(
|
||||
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
@ -704,7 +719,6 @@ def main():
|
||||
del all_start_logits
|
||||
del all_end_logits
|
||||
|
||||
eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
|
||||
outputs_numpy = (start_logits_concat, end_logits_concat)
|
||||
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
|
||||
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
@ -736,8 +750,6 @@ def main():
|
||||
del all_start_logits
|
||||
del all_end_logits
|
||||
|
||||
# Now we need to add extra columns which we removed for post processing
|
||||
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
|
||||
outputs_numpy = (start_logits_concat, end_logits_concat)
|
||||
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
|
||||
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
|
Loading…
Reference in New Issue
Block a user