New run glue script (#7917)

* Start simplification

* More progress

* Finished script

* Address comments and update tests instructions

* Wrong test

* Accept files as inputs and fix test

* Update src/transformers/trainer_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Fix labels and add combined score

* Add special labels

* Update TPU command

* Revert to old label strategy

* Use model labels

* Fix for STS-B

* Styling

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Code styling

* Fix review comments

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Sylvain Gugger 2020-10-22 11:42:22 -04:00 committed by GitHub
parent 18ce6b8ff3
commit 2e5052d4f1
8 changed files with 331 additions and 170 deletions

View File

@@ -67,10 +67,10 @@ class ExamplesTests(TestCasePlus):
testargs = f"""
run_glue.py
--model_name_or_path distilbert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--overwrite_output_dir
--task_name mrpc
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
--do_train
--do_eval
--per_device_train_batch_size=2
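For reference, the test arguments above correspond roughly to the following shell invocation (a sketch, run from the repository root; `--output_dir /tmp/mrpc_csv` stands in for the test's temporary directory, the fixture paths are the ones used by the test):

```bash
python examples/text-classification/run_glue.py \
  --model_name_or_path distilbert-base-uncased \
  --train_file ./tests/fixtures/tests_samples/MRPC/train.csv \
  --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv \
  --do_train \
  --do_eval \
  --per_device_train_batch_size 2 \
  --overwrite_output_dir \
  --output_dir /tmp/mrpc_csv
```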

View File

@@ -44,8 +44,7 @@ class TorchXLAExamplesTests(unittest.TestCase):
transformers/examples/text-classification/run_glue.py
--do_train
--do_eval
--task_name=MRPC
--data_dir=/datasets/glue_data/MRPC
--task_name=mrpc
--cache_dir=./cache_dir
--num_train_epochs=1
--max_seq_length=128

View File

@@ -74,18 +74,10 @@ between different runs. We report the median on 5 runs (with different seeds) fo
| WNLI | Accuracy | 45.07 |
Some of these results are significantly different from the ones reported on the test set
of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
Before running any one of these GLUE tasks you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running the following lines at the root of the repo
```
python utils/download_glue_data.py --data_dir /path/to/glue --tasks all
```
after replacing *path/to/glue* with a value that you like. Then you can run
of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the
website.
```bash
export GLUE_DIR=/path/to/glue
export TASK_NAME=MRPC
python run_glue.py \
@@ -93,7 +85,6 @@ python run_glue.py \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/$TASK_NAME \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
@@ -114,69 +105,33 @@ since the data processor for each task inherits from the base class DataProcesso
## Running on TPUs in PyTorch
**Update**: read the more up-to-date [Running on TPUs](../README.md#running-on-tpus) in the main README.md instead.
Even when running PyTorch, you can accelerate your workloads on Google's TPUs, using `pytorch/xla`. For information on how to setup your TPU environment refer to the
Even when running PyTorch, you can accelerate your workloads on Google's TPUs, using `pytorch/xla`. For information on
how to setup your TPU environment refer to the
[pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
The following are some examples of running the `*_tpu.py` finetuning scripts on TPUs. All steps for data preparation are
identical to your normal GPU + Huggingface setup.
For running your GLUE task on MNLI dataset you can run something like the following:
For running your GLUE task on the MNLI dataset you can run something like the following from the root of the transformers
repo:
```
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
export GLUE_DIR=/path/to/glue
export TASK_NAME=MNLI
python run_glue_tpu.py \
--model_name_or_path bert-base-cased \
--task_name $TASK_NAME \
python examples/xla_spawn.py \
--num_cores=8 \
transformers/examples/text-classification/run_glue.py \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/$TASK_NAME \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 3e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/$TASK_NAME \
--task_name=mrpc \
--num_train_epochs=3 \
--max_seq_length=128 \
--learning_rate=5e-5 \
--output_dir=/tmp/mrpc \
--overwrite_output_dir \
--logging_steps 50 \
--save_steps 200 \
--num_cores=8
--logging_steps=5 \
--save_steps=5 \
--tpu_metrics_debug \
--model_name_or_path=bert-base-cased \
--per_device_train_batch_size=64 \
--per_device_eval_batch_size=64
```
### MRPC
#### Fine-tuning example
The following examples fine-tune BERT on the Microsoft Research Paraphrase Corpus (MRPC) and run in less
than 10 minutes on a single K-80, and in 27 seconds (!) on a single Tesla V100 16GB with apex installed.
Before running any one of these GLUE tasks you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
```bash
export GLUE_DIR=/path/to/glue
python run_glue.py \
--model_name_or_path bert-base-cased \
--task_name MRPC \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/MRPC/ \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
```
Our tests ran on a few seeds with [the original implementation
hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) and gave
evaluation results between 84% and 88%.
#### Using Apex and mixed-precision
@@ -184,14 +139,12 @@ Using Apex and 16 bit precision, the fine-tuning on MRPC only takes 27 seconds.
[apex](https://github.com/NVIDIA/apex), then run the following example:
```bash
export GLUE_DIR=/path/to/glue
python run_glue.py \
--model_name_or_path bert-base-cased \
--task_name MRPC \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/MRPC/ \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
@@ -206,15 +159,13 @@ Here is an example using distributed training on 8 V100 GPUs. The model used is
reaches F1 > 92 on MRPC.
```bash
export GLUE_DIR=/path/to/glue
python -m torch.distributed.launch \
--nproc_per_node 8 run_glue.py \
--model_name_or_path bert-base-cased \
--task_name MRPC \
--task_name mrpc \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/MRPC/ \
--max_seq_length 128 \
--per_device_train_batch_size 8 \
--learning_rate 2e-5 \
@@ -246,7 +197,6 @@ python -m torch.distributed.launch \
--task_name mnli \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/MNLI/ \
--max_seq_length 128 \
--per_device_train_batch_size 8 \
--learning_rate 2e-5 \
@@ -272,7 +222,9 @@ The results are the following:
# Run PyTorch version using PyTorch-Lightning
Run `bash run_pl.sh` from the `glue` directory. This will also install `pytorch-lightning` and the requirements in `examples/requirements.txt`. It is a shell pipeline that will automatically download, pre-process the data and run the specified models. Logs are saved in `lightning_logs` directory.
Run `bash run_pl.sh` from the `glue` directory. This will also install `pytorch-lightning` and the requirements in
`examples/requirements.txt`. It is a shell pipeline that will automatically download and preprocess the data, then run
the specified models. Logs are saved in the `lightning_logs` directory.
Pass the `--gpus` flag to change the number of GPUs. By default, 1 GPU is used. At the end, the expected results are:

View File

@@ -14,33 +14,101 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE."""
# You can also adapt this script on your own text classification task. Pointers for this are left as comments.
import dataclasses
import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from typing import Optional
import numpy as np
from datasets import load_dataset, load_metric
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
import transformers
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
EvalPrediction,
HfArgumentParser,
PretrainedConfig,
Trainer,
TrainingArguments,
glue_compute_metrics,
glue_output_modes,
glue_tasks_num_labels,
default_data_collator,
set_seed,
)
from transformers.trainer_utils import is_main_process
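# Map each GLUE task to the names of the dataset columns holding its input sentence(s); single-sentence tasks use
# None as the second key.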
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
logger = logging.getLogger(__name__)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input to our model for training and eval.
Using `HfArgumentParser`, we can turn this class into argparse arguments that can be specified on
the command line.
"""
task_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
)
max_seq_length: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
pad_to_max_length: bool = field(
default=True,
metadata={
"help": "Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
train_file: Optional[str] = field(
default=None, metadata={"help": "A csv or a json file containing the training data."}
)
validation_file: Optional[str] = field(
default=None, metadata={"help": "A csv or a json file containing the validation data."}
)
def __post_init__(self):
if self.task_name is not None:
self.task_name = self.task_name.lower()
if self.task_name not in task_to_keys.keys():
raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
elif self.train_file is None or self.validation_file is None:
raise ValueError("Need either a GLUE task or a training/validation file.")
else:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
@dataclass
class ModelArguments:
"""
@@ -59,6 +127,10 @@ class ModelArguments:
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
def main():
@@ -67,7 +139,6 @@ def main():
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
@@ -82,40 +153,82 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
# Set seed
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
try:
num_labels = glue_tasks_num_labels[data_args.task_name]
output_mode = glue_output_modes[data_args.task_name]
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name))
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use as labels the column called 'label' and as the pair of sentences the
# sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not named
# 'label' if at least two columns are provided.
#
# If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
# single column. You can easily tweak this behavior (see below)
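# For example (hypothetical files): a CSV with columns `label,review` has a single non-label column, so the script
# would classify the `review` text on its own, while a CSV with columns `label,text_a,text_b` would be treated as a
# sentence-pair task on `text_a` and `text_b`.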
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if data_args.task_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset("glue", data_args.task_name)
elif data_args.train_file.endswith(".csv"):
# Loading a dataset from local csv files
datasets = load_dataset(
"csv", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
)
else:
# Loading a dataset from local json files
datasets = load_dataset(
"json", data_files={"train": data_args.train_file, "validation": data_args.validation_file}
)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Labels
if data_args.task_name is not None:
is_regression = data_args.task_name == "stsb"
if not is_regression:
label_list = datasets["train"].features["label"].names
num_labels = len(label_list)
else:
num_labels = 1
else:
# Trying to have good defaults here, don't hesitate to tweak to your needs.
is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"]
if is_regression:
num_labels = 1
else:
# A useful fast method:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
label_list = datasets["train"].unique("label")
label_list.sort() # Let's sort it for determinism
num_labels = len(label_list)
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
@@ -125,6 +238,7 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
)
model = AutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
@@ -133,39 +247,103 @@ def main():
cache_dir=model_args.cache_dir,
)
# Get datasets
train_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
)
eval_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
if training_args.do_eval
else None
)
test_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
if training_args.do_predict
else None
)
# Preprocessing the datasets
if data_args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
else:
# Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
non_label_column_names = [name for name in datasets["train"].column_names if name != "label"]
if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
sentence1_key, sentence2_key = "sentence1", "sentence2"
else:
if len(non_label_column_names) >= 2:
sentence1_key, sentence2_key = non_label_column_names[:2]
else:
sentence1_key, sentence2_key = non_label_column_names[0], None
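# E.g. (hypothetical columns): ["label", "question", "passage"] yields ("question", "passage"), while
# ["label", "text"] yields ("text", None).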
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
if output_mode == "classification":
preds = np.argmax(preds, axis=1)
else: # regression
preds = np.squeeze(preds)
return glue_compute_metrics(task_name, preds, p.label_ids)
# Padding strategy
if data_args.pad_to_max_length:
padding = "max_length"
max_length = data_args.max_seq_length
else:
# We will pad later, dynamically at batch creation, to the max sequence length in each batch
padding = False
max_length = None
return compute_metrics_fn
# Some models have set the order of the labels to use, so let's make sure we do use it.
label_to_id = None
if (
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
and data_args.task_name is not None
and not is_regression
):
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: "
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result."
)
elif data_args.task_name is None:
label_to_id = {v: i for i, v in enumerate(label_list)}
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=padding, max_length=max_length, truncation=True)
# Map labels to IDs (not necessary for GLUE tasks)
if label_to_id is not None and "label" in examples:
result["label"] = [label_to_id[l] for l in examples["label"]]
return result
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
train_dataset = datasets["train"]
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.task_name is not None:
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Get the metric function
if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name)
# TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
# compute_metrics
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
# `predictions` and `label_ids` fields) and has to return a dictionary mapping strings to floats.
def compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
if data_args.task_name is not None:
result = metric.compute(predictions=preds, references=p.label_ids)
if len(result) > 1:
result["combined_score"] = np.mean(list(result.values())).item()
return result
elif is_regression:
return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
else:
return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=build_compute_metrics_fn(data_args.task_name),
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
data_collator=default_data_collator if data_args.pad_to_max_length else None,
)
# Training
@@ -173,11 +351,7 @@ def main():
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_master():
tokenizer.save_pretrained(training_args.output_dir)
trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation
eval_results = {}
@@ -185,56 +359,52 @@ def main():
logger.info("*** Evaluate ***")
# Loop to handle MNLI double evaluation (matched, mis-matched)
tasks = [data_args.task_name]
eval_datasets = [eval_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
eval_datasets.append(
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
)
tasks.append("mnli-mm")
eval_datasets.append(datasets["validation_mismatched"])
for eval_dataset in eval_datasets:
trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
for eval_dataset, task in zip(eval_datasets, tasks):
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
output_eval_file = os.path.join(
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
)
if trainer.is_world_master():
output_eval_file = os.path.join(training_args.output_dir, f"eval_results_{task}.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
logger.info(f"***** Eval results {task} *****")
for key, value in eval_result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
eval_results.update(eval_result)
if training_args.do_predict:
logging.info("*** Test ***")
logger.info("*** Test ***")
# Loop to handle MNLI double evaluation (matched, mis-matched)
tasks = [data_args.task_name]
test_datasets = [test_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
test_datasets.append(
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
)
tasks.append("mnli-mm")
test_datasets.append(datasets["test_mismatched"])
for test_dataset in test_datasets:
for test_dataset, task in zip(test_datasets, tasks):
# Removing the `label` column because it contains -1 and Trainer won't like that.
test_dataset.remove_columns_("label")
predictions = trainer.predict(test_dataset=test_dataset).predictions
if output_mode == "classification":
predictions = np.argmax(predictions, axis=1)
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
output_test_file = os.path.join(
training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
)
if trainer.is_world_master():
output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
logger.info(f"***** Test results {task} *****")
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
if output_mode == "regression":
writer.write("%d\t%3.3f\n" % (index, item))
if is_regression:
writer.write(f"{index}\t{item:3.3f}\n")
else:
item = test_dataset.get_labels()[item]
writer.write("%d\t%s\n" % (index, item))
item = label_list[item]
writer.write(f"{index}\t{item}\n")
return eval_results

View File

@@ -854,8 +854,6 @@ class Trainer:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
@@ -1173,7 +1171,7 @@ class Trainer:
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None):
@@ -1188,7 +1186,7 @@ class Trainer:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
@@ -1272,6 +1270,7 @@ class Trainer:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def predict(self, test_dataset: Dataset) -> PredictionOutput:

View File

@@ -22,7 +22,7 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import numpy as np
from .file_utils import is_tf_available, is_torch_available
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from .tokenization_utils_base import ExplicitEnum
@@ -157,3 +157,30 @@ default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
}
def is_main_process(local_rank):
"""
Whether or not the current process is the main process, based on `xm.get_ordinal()` (for TPUs) first, then on
`local_rank`.
"""
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
return xm.get_ordinal() == 0
return local_rank in [-1, 0]
def total_processes_number(local_rank):
"""
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
"""
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
return xm.xrt_world_size()
elif local_rank != -1 and is_torch_available():
import torch
return torch.distributed.get_world_size()
return 1

View File

@@ -0,0 +1,7 @@
label,sentence1,sentence2
equivalent,He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .,""" The foodservice pie business does not fit our long-term growth strategy ."
not_equivalent,Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .,"His wife said he was "" 100 percent behind George Bush "" and looked forward to using his years of training in the war ."
not_equivalent,"The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .","The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent ."
equivalent,The AFL-CIO is waiting until October to decide if it will endorse a candidate .,The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .
not_equivalent,No dates have been set for the civil or the criminal trial .,"No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty ."
equivalent,Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed .,It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .

View File

@@ -0,0 +1,7 @@
label,sentence1,sentence2
equivalent,He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .,""" The foodservice pie business does not fit our long-term growth strategy ."
not_equivalent,Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .,"His wife said he was "" 100 percent behind George Bush "" and looked forward to using his years of training in the war ."
not_equivalent,"The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .","The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent ."
equivalent,The AFL-CIO is waiting until October to decide if it will endorse a candidate .,The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .
not_equivalent,No dates have been set for the civil or the criminal trial .,"No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty ."
equivalent,Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed .,It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .