[Examples] Added predict stage and Updated Example Template (#10868)
* added predict stage
* added test keyword in exception message
* removed example specific saving predictions
* fixed f-string error
* removed extra line

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent fb2b89840b
commit 7ef40120a0
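For orientation before the diff itself: the sketch below condenses the predict stage this commit adds to the XNLI example into one helper. The `run_predict_stage` wrapper and its signature are illustrative only; the individual calls (`trainer.predict`, `log_metrics`/`save_metrics`, the argmax over logits, the `test_predictions.txt` TSV format) are taken directly from the diff that follows.

# Illustrative sketch only: the helper name and signature are not part of the commit;
# the body mirrors the predict stage added to the XNLI example below.
import os

import numpy as np
from transformers import Trainer


def run_predict_stage(trainer: Trainer, test_dataset, label_list, output_dir: str) -> None:
    # Trainer.predict returns (predictions, label_ids, metrics) for the given dataset.
    predictions, _labels, metrics = trainer.predict(test_dataset)
    trainer.log_metrics("test", metrics)
    trainer.save_metrics("test", metrics)

    # Turn logits into class ids, map them to label names, and write a TSV file
    # from the main process only.
    predictions = np.argmax(predictions, axis=1)
    output_test_file = os.path.join(output_dir, "test_predictions.txt")
    if trainer.is_world_process_zero():
        with open(output_test_file, "w") as writer:
            writer.write("index\tprediction\n")
            for index, item in enumerate(predictions):
                writer.write(f"{index}\t{label_list[item]}\n")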
@@ -207,14 +207,22 @@ def main():
     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
     # download the dataset.
     # Downloading and loading xnli dataset from the hub.
-    if model_args.train_language is None:
-        train_dataset = load_dataset("xnli", model_args.language, split="train")
-    else:
-        train_dataset = load_dataset("xnli", model_args.train_language, split="train")
+    if training_args.do_train:
+        if model_args.train_language is None:
+            train_dataset = load_dataset("xnli", model_args.language, split="train")
+        else:
+            train_dataset = load_dataset("xnli", model_args.train_language, split="train")
+        label_list = train_dataset.features["label"].names
+
+    if training_args.do_eval:
+        eval_dataset = load_dataset("xnli", model_args.language, split="validation")
+        label_list = eval_dataset.features["label"].names
+
+    if training_args.do_predict:
+        test_dataset = load_dataset("xnli", model_args.language, split="test")
+        label_list = test_dataset.features["label"].names
 
-    eval_dataset = load_dataset("xnli", model_args.language, split="validation")
     # Labels
-    label_list = train_dataset.features["label"].names
     num_labels = len(label_list)
 
     # Load pretrained model and tokenizer
@@ -271,6 +279,9 @@ def main():
             batched=True,
             load_from_cache_file=not data_args.overwrite_cache,
         )
+        # Log a few random samples from the training set:
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
 
     if training_args.do_eval:
         if data_args.max_val_samples is not None:
@@ -281,9 +292,14 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
-    # Log a few random samples from the training set:
-    for index in random.sample(range(len(train_dataset)), 3):
-        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+    if training_args.do_predict:
+        if data_args.max_test_samples is not None:
+            test_dataset = test_dataset.select(range(data_args.max_test_samples))
+        test_dataset = test_dataset.map(
+            preprocess_function,
+            batched=True,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
 
     # Get the metric function
     metric = load_metric("xnli")
@@ -307,7 +323,7 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_dataset,
+        train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         compute_metrics=compute_metrics,
         tokenizer=tokenizer,
@@ -346,6 +362,26 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
+    # Prediction
+    if training_args.do_predict:
+        logger.info("*** Predict ***")
+        predictions, labels, metrics = trainer.predict(test_dataset)
+
+        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
+        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+
+        trainer.log_metrics("test", metrics)
+        trainer.save_metrics("test", metrics)
+
+        predictions = np.argmax(predictions, axis=1)
+        output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
+        if trainer.is_world_process_zero():
+            with open(output_test_file, "w") as writer:
+                writer.write("index\tprediction\n")
+                for index, item in enumerate(predictions):
+                    item = label_list[item]
+                    writer.write(f"{index}\t{item}\n")
+
 
 if __name__ == "__main__":
     main()
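Per the commit message, example-specific prediction saving was removed elsewhere, so the TSV written above is the XNLI example's prediction artifact when --do_predict is passed. A minimal sketch for loading it back; the file name and the tab-separated index/prediction columns come from the diff, while the output directory value is an assumption:

# Minimal sketch: read back the predictions TSV written by the XNLI example's
# predict stage; assumes it lives in the run's output_dir.
import csv
import os

output_dir = "out"  # assumed; use whatever --output_dir the run was given
with open(os.path.join(output_dir, "test_predictions.txt"), newline="") as f:
    rows = list(csv.DictReader(f, delimiter="\t"))

print(rows[0])  # e.g. {"index": "0", "prediction": "<label name>"}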
@@ -139,6 +139,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input test data file to predict the label on (a text file)."},
+    )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -160,10 +164,22 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    max_test_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "value if set."
+        },
+    )
 
     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training/validation/test file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -171,6 +187,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`test_file` should be a csv, a json or a txt file."
 
 
 def main():
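The two new DataTrainingArguments fields above surface automatically as command-line flags through HfArgumentParser. Below is a self-contained sketch of that behaviour; the dataclass here is a trimmed stand-in reproducing only the fields added by this commit, and the example argument values are made up:

# Sketch: the new fields become CLI flags via HfArgumentParser.
# The dataclass is a trimmed stand-in for the template's DataTrainingArguments.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class DataTrainingArguments:
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to predict the label on (a text file)."},
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of test examples for debugging or quicker runs."},
    )


parser = HfArgumentParser((DataTrainingArguments, TrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--do_predict", "--test_file", "test.csv", "--max_test_samples", "100"]
)
assert training_args.do_predict
assert data_args.test_file == "test.csv" and data_args.max_test_samples == 100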
@@ -238,9 +257,13 @@ def main():
         data_files = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
+            extension = data_args.train_file.split(".")[-1]
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
-        extension = data_args.train_file.split(".")[-1]
+            extension = data_args.validation_file.split(".")[-1]
+        if data_args.test_file is not None:
+            data_files["test"] = data_args.test_file
+            extension = data_args.test_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
         datasets = load_dataset(extension, data_files=data_files)
@@ -326,8 +349,10 @@ def main():
     # First we tokenize all the texts.
     if training_args.do_train:
         column_names = datasets["train"].column_names
-    else:
+    elif training_args.do_eval:
         column_names = datasets["validation"].column_names
+    elif training_args.do_predict:
+        column_names = datasets["test"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     def tokenize_function(examples):
@@ -365,6 +390,22 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
+    if training_args.do_predict:
+        if "test" not in datasets:
+            raise ValueError("--do_predict requires a test dataset")
+        test_dataset = datasets["test"]
+        # Selecting samples from dataset
+        if data_args.max_test_samples is not None:
+            test_dataset = test_dataset.select(range(data_args.max_test_samples))
+        # tokenize test dataset
+        test_dataset = test_dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=[text_column_name],
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+
     # Data collator
     data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 
@@ -420,6 +461,18 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
+    # Prediction
+    if training_args.do_predict:
+        logger.info("*** Predict ***")
+        predictions, labels, metrics = trainer.predict(test_dataset)
+
+        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
+        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+
+        trainer.log_metrics("test", metrics)
+        trainer.save_metrics("test", metrics)
+
+        # write custom code for saving predictions according to task
 
 def _mp_fn(index):
     # For xla_spawn (TPUs)