Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
[Examples] Added support for test-file in QA examples with no trainer (#11510)
* added support for test-file
* fixed typo
* added suggested changes
* reformatted code
* modified files
* fix post processing error
* Trigger CI
* removed extra lines
This commit is contained in:
parent af0692a2ca
commit 84326a28f8
datasets (new submodule)
@@ -0,0 +1 @@
+Subproject commit 8afd0ba8c27800a55ea69d9fcd702dc97d9c16d8
@@ -172,8 +172,6 @@ accelerate test
 that will check everything is ready for training. Finally, you can launch training with
 
 ```bash
-export TASK_NAME=mrpc
-
 accelerate launch run_qa_no_trainer.py \
   --model_name_or_path bert-base-uncased \
   --dataset_name squad \
@@ -80,6 +80,9 @@ def parse_args():
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
+    parser.add_argument(
+        "--test_file", type=str, default=None, help="A csv or a json file containing the Prediction data."
+    )
     parser.add_argument(
         "--max_seq_length",
         type=int,
@@ -202,8 +205,13 @@ def parse_args():
     args = parser.parse_args()
 
     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
+    if (
+        args.dataset_name is None
+        and args.train_file is None
+        and args.validation_file is None
+        and args.test_file is None
+    ):
+        raise ValueError("Need either a dataset name or a training/validation/test file.")
     else:
         if args.train_file is not None:
             extension = args.train_file.split(".")[-1]
@@ -211,6 +219,9 @@ def parse_args():
         if args.validation_file is not None:
             extension = args.validation_file.split(".")[-1]
             assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        if args.test_file is not None:
+            extension = args.test_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
 
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
@@ -263,8 +274,10 @@ def main():
             data_files["train"] = args.train_file
         if args.validation_file is not None:
             data_files["validation"] = args.validation_file
+        if args.test_file is not None:
+            data_files["test"] = args.test_file
         extension = args.train_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, data_files=data_files, field="data")
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
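The new `field="data"` argument makes the `json` loader read records from the top-level `"data"` key instead of the file root. Below is a minimal sketch of a matching file and load call; the file name and the flattened SQuAD-style record layout are illustrative assumptions, not taken from this diff.

```python
import json

from datasets import load_dataset

# Illustrative test file with a top-level "data" key holding flattened
# SQuAD-style records (file name and contents are made up for this sketch).
records = {
    "data": [
        {
            "id": "q-0",
            "title": "Example",
            "context": "The Broncos won Super Bowl 50.",
            "question": "Who won Super Bowl 50?",
            "answers": {"text": ["The Broncos"], "answer_start": [0]},
        }
    ]
}
with open("my_test.json", "w") as f:
    json.dump(records, f)

# Mirrors the call in the diff: the json builder reads the list under "data".
raw_datasets = load_dataset("json", data_files={"test": "my_test.json"}, field="data")
print(raw_datasets["test"][0]["question"])
```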
@@ -535,13 +548,15 @@ def main():
         train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
     )
 
-    eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
+    eval_dataloader = DataLoader(
+        eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+    )
 
     if args.do_predict:
-        predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
+        predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"])
         predict_dataloader = DataLoader(
-            predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+            predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
         )
 
     # Post-processing:
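Replacing `set_format` with `remove_columns` means the dataloader only ever sees tensor-friendly columns, while the untouched `eval_dataset` (which still carries `example_id` and `offset_mapping`) remains available for post-processing. A self-contained sketch of that pattern with toy token ids; the column values here are invented for illustration.

```python
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator

# Toy features mimicking the QA preprocessing output.
eval_dataset = Dataset.from_dict(
    {
        "input_ids": [[101, 2054, 102], [101, 2129, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1]],
        "token_type_ids": [[0, 0, 0], [0, 0, 0]],
        "example_id": ["q-0", "q-1"],                    # strings: the model forward cannot take these
        "offset_mapping": [[[0, 0]] * 3, [[0, 0]] * 3],  # kept only for post-processing
    }
)

# Strip the post-processing-only columns before batching for the model ...
eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
eval_dataloader = DataLoader(eval_dataset_for_model, collate_fn=default_data_collator, batch_size=2)

# ... while eval_dataset itself still has every column for mapping logits back to answers.
for batch in eval_dataloader:
    print(batch["input_ids"].shape)  # torch.Size([2, 3])
```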
@@ -709,21 +724,21 @@ def main():
     start_top_index_concat = create_and_fill_np_array(all_start_top_index, eval_dataset, max_len)
     end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, eval_dataset, max_len)
     end_top_index_concat = create_and_fill_np_array(all_end_top_index, eval_dataset, max_len)
-    all_cls_logits = np.concatenate(all_cls_logits, axis=0)
+    cls_logits_concat = np.concatenate(all_cls_logits, axis=0)
 
     # delete the list of numpy arrays
     del start_top_log_probs
     del start_top_index
     del end_top_log_probs
     del end_top_index
+    del cls_logits
 
-    eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
     outputs_numpy = (
         start_top_log_probs_concat,
         start_top_index_concat,
         end_top_log_probs_concat,
         end_top_index_concat,
-        cls_logits,
+        cls_logits_concat,
     )
     prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
     eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
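`create_and_fill_np_array` itself is untouched by this diff and not shown here; for orientation, a gather-and-pad helper of that general shape could look like the hypothetical sketch below (the name, padding value, and truncation handling are assumptions, not the script's actual implementation).

```python
import numpy as np

def gather_and_pad(per_batch_logits, num_examples, max_len, pad_value=-100.0):
    """Hypothetical stand-in for a helper like create_and_fill_np_array:
    stack per-batch logit arrays of varying sequence length into one
    (num_examples, max_len) array, padding short rows with pad_value."""
    out = np.full((num_examples, max_len), pad_value, dtype=np.float64)
    row = 0
    for batch in per_batch_logits:
        take = min(batch.shape[0], num_examples - row)  # drop duplicated samples from the last batch
        if take <= 0:
            break
        out[row : row + take, : batch.shape[1]] = batch[:take]
        row += take
    return out

# Toy usage: two batches with different sequence lengths.
batches = [np.ones((2, 3)), 2 * np.ones((2, 2))]
print(gather_and_pad(batches, num_examples=3, max_len=4).shape)  # (3, 4)
```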
@@ -766,21 +781,21 @@ def main():
     start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
     end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
     end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
-    all_cls_logits = np.concatenate(all_cls_logits, axis=0)
+    cls_logits_concat = np.concatenate(all_cls_logits, axis=0)
 
     # delete the list of numpy arrays
     del start_top_log_probs
     del start_top_index
     del end_top_log_probs
     del end_top_index
+    del cls_logits
 
-    predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
     outputs_numpy = (
         start_top_log_probs_concat,
         start_top_index_concat,
         end_top_log_probs_concat,
         end_top_index_concat,
-        cls_logits,
+        cls_logits_concat,
     )
 
     prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
@@ -81,10 +81,13 @@ def parse_args():
     parser.add_argument(
         "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
     )
-    parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
+    parser.add_argument("--do_predict", action="store_true", help="To do prediction on the question answering model")
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
+    parser.add_argument(
+        "--test_file", type=str, default=None, help="A csv or a json file containing the Prediction data."
+    )
     parser.add_argument(
         "--max_seq_length",
         type=int,
@@ -231,8 +234,13 @@ def parse_args():
     args = parser.parse_args()
 
     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
+    if (
+        args.dataset_name is None
+        and args.train_file is None
+        and args.validation_file is None
+        and args.test_file is None
+    ):
+        raise ValueError("Need either a dataset name or a training/validation/test file.")
     else:
         if args.train_file is not None:
             extension = args.train_file.split(".")[-1]
@@ -240,6 +248,9 @@ def parse_args():
         if args.validation_file is not None:
             extension = args.validation_file.split(".")[-1]
             assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        if args.test_file is not None:
+            extension = args.test_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
 
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
@@ -292,8 +303,10 @@ def main():
             data_files["train"] = args.train_file
         if args.validation_file is not None:
             data_files["validation"] = args.validation_file
+        if args.test_file is not None:
+            data_files["test"] = args.test_file
         extension = args.train_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, data_files=data_files, field="data")
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
@@ -540,13 +553,15 @@ def main():
         train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
     )
 
-    eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
+    eval_dataloader = DataLoader(
+        eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+    )
 
     if args.do_predict:
-        predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
+        predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"])
         predict_dataloader = DataLoader(
-            predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+            predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
         )
 
     # Post-processing:
@@ -704,7 +719,6 @@ def main():
     del all_start_logits
     del all_end_logits
 
-    eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
     outputs_numpy = (start_logits_concat, end_logits_concat)
     prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
     eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
@@ -736,8 +750,6 @@ def main():
     del all_start_logits
     del all_end_logits
 
-    # Now we need to add extra columns which we removed for post processing
-    predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
     outputs_numpy = (start_logits_concat, end_logits_concat)
     prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
     predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
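For reference, `prediction.predictions` and `prediction.label_ids` feed `metric.compute` using the squad metric schema. A tiny hedged example with a made-up id and answer follows; these scripts may switch to the squad_v2 metric when version_2_with_negative is enabled.

```python
from datasets import load_metric

metric = load_metric("squad")  # plain squad metric; squad_v2 adds no-answer handling
predictions = [{"id": "q-0", "prediction_text": "Denver Broncos"}]
references = [{"id": "q-0", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]
print(metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}
```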