TF version of the trainer (#4017)

* First commit to add a TF version of the trainer.

* Make the TF trainer closer to what looks the PT trainer

* Refactoring common code between the PT and TF trainer into an util file.

* Some bugfix + better similarity with the PT trainer

* Add missing class in transformers init

* Bugfix over prediction + use classification report instead of simple metrics

* Fix name error

* Fix optimization tests + style

* Apply style

* Several bugfix for multi-gpu training

* Apply style

* Apply style

* Add glue example for the TF trainer

* Several bugix + address the reviews

* Fix on the TF training args file

* Add a debug mode

* Bugfix in utils_ner.py when segment_ids is None

* Apply style

* Apply style

* Add TPU strategy

* Fix selection strategy
This commit is contained in:
Julien Plu 2020-05-06 18:56:52 +02:00 committed by GitHub
parent 25296b12aa
commit aad50151f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1206 additions and 819 deletions

View File

@ -1,628 +1,281 @@
# coding=utf-8
import collections
import datetime
import glob
import math
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for named entity recognition."""
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
import tensorflow as tf
from absl import app, flags, logging
from fastprogress import master_bar, progress_bar
from seqeval import metrics
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from transformers import (
TF2_WEIGHTS_NAME,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
AutoConfig,
AutoTokenizer,
GradientAccumulator,
PreTrainedTokenizer,
EvalPrediction,
HfArgumentParser,
TFAutoModelForTokenClassification,
create_optimizer,
TFTrainer,
TFTrainingArguments,
)
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
from utils_ner import Split, TFNerDataset, get_labels
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
logger = logging.getLogger(__name__)
flags.DEFINE_string(
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) for the task."
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
flags.DEFINE_string(
"model_name_or_path", None, "Path to pretrained model or model identifier from huggingface.co/models",
)
flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")
flags.DEFINE_string(
"labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)
flags.DEFINE_string("config_name", None, "Pretrained config name or path if not the same as model_name")
flags.DEFINE_string("tokenizer_name", None, "Pretrained tokenizer name or path if not the same as model_name")
flags.DEFINE_string("cache_dir", None, "Where do you want to store the pre-trained models downloaded from s3")
flags.DEFINE_integer(
"max_seq_length",
128,
"The maximum total input sentence length after tokenization. "
"Sequences longer than this will be truncated, sequences shorter "
"will be padded.",
)
flags.DEFINE_string(
"tpu",
None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.",
)
flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")
flags.DEFINE_boolean("do_train", False, "Whether to run training.")
flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")
flags.DEFINE_boolean(
"evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
)
flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")
flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")
flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")
flags.DEFINE_integer(
"gradient_accumulation_steps", 1, "Number of updates steps to accumulate before performing a backward/update pass."
)
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")
flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")
flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")
flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")
flags.DEFINE_integer(
"max_steps", -1, "If > 0: set total number of training steps to perform. Override num_train_epochs."
)
flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")
flags.DEFINE_integer("logging_steps", 50, "Log every X updates steps.")
flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X updates steps.")
flags.DEFINE_boolean(
"eval_all_checkpoints",
False,
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA even if it is available")
flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")
flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")
flags.DEFINE_integer("seed", 42, "random seed for initialization")
flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")
flags.DEFINE_string(
"gpus",
"0",
"Comma separated list of gpus devices. If only one, switch to single "
"gpu strategy, if None takes all the gpus available.",
)
def train(
args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
):
if args["max_steps"] > 0:
num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
args["num_train_epochs"] = 1
else:
num_train_steps = (
math.ceil(num_train_examples / train_batch_size)
// args["gradient_accumulation_steps"]
* args["num_train_epochs"]
)
writer = tf.summary.create_file_writer("/tmp/mylogs")
with strategy.scope():
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])
if args["fp16"]:
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")
loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
gradient_accumulator = GradientAccumulator()
logging.info("***** Running training *****")
logging.info(" Num examples = %d", num_train_examples)
logging.info(" Num Epochs = %d", args["num_train_epochs"])
logging.info(" Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
logging.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
train_batch_size * args["gradient_accumulation_steps"],
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
logging.info(" Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
logging.info(" Total training steps = %d", num_train_steps)
model.summary()
@tf.function
def apply_gradients():
grads_and_vars = []
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
if gradient is not None:
scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
grads_and_vars.append((scaled_gradient, variable))
else:
grads_and_vars.append((gradient, variable))
optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
gradient_accumulator.reset()
@tf.function
def train_step(train_features, train_labels):
def step_fn(train_features, train_labels):
inputs = {"attention_mask": train_features["attention_mask"], "training": True}
if "token_type_ids" in train_features:
inputs["token_type_ids"] = train_features["token_type_ids"]
with tf.GradientTape() as tape:
logits = model(train_features["input_ids"], **inputs)[0]
active_loss = tf.reshape(train_labels, (-1,)) != pad_token_label_id
active_logits = tf.boolean_mask(tf.reshape(logits, (-1, len(labels))), active_loss)
active_labels = tf.boolean_mask(tf.reshape(train_labels, (-1,)), active_loss)
cross_entropy = loss_fct(active_labels, active_logits)
loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
grads = tape.gradient(loss, model.trainable_variables)
gradient_accumulator(grads)
return cross_entropy
per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
return mean_loss
current_time = datetime.datetime.now()
train_iterator = master_bar(range(args["num_train_epochs"]))
global_step = 0
logging_loss = 0.0
for epoch in train_iterator:
epoch_iterator = progress_bar(
train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
)
step = 1
with strategy.scope():
for train_features, train_labels in epoch_iterator:
loss = train_step(train_features, train_labels)
if step % args["gradient_accumulation_steps"] == 0:
strategy.experimental_run_v2(apply_gradients)
loss_metric(loss)
global_step += 1
if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
# Log metrics
if (
args["n_device"] == 1 and args["evaluate_during_training"]
): # Only evaluate when single GPU otherwise metrics may not average well
y_true, y_pred, eval_loss = evaluate(
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
)
report = metrics.classification_report(y_true, y_pred, digits=4)
logging.info("Eval at step " + str(global_step) + "\n" + report)
logging.info("eval_loss: " + str(eval_loss))
precision = metrics.precision_score(y_true, y_pred)
recall = metrics.recall_score(y_true, y_pred)
f1 = metrics.f1_score(y_true, y_pred)
with writer.as_default():
tf.summary.scalar("eval_loss", eval_loss, global_step)
tf.summary.scalar("precision", precision, global_step)
tf.summary.scalar("recall", recall, global_step)
tf.summary.scalar("f1", f1, global_step)
lr = optimizer.learning_rate
learning_rate = lr(step)
with writer.as_default():
tf.summary.scalar("lr", learning_rate, global_step)
tf.summary.scalar(
"loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
)
logging_loss = loss_metric.result()
with writer.as_default():
tf.summary.scalar("loss", loss_metric.result(), step=step)
if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
# Save model checkpoint
output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.save_pretrained(output_dir)
logging.info("Saving model checkpoint to %s", output_dir)
train_iterator.child.comment = f"loss : {loss_metric.result()}"
step += 1
train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")
loss_metric.reset_states()
logging.info(" Training took time = {}".format(datetime.datetime.now() - current_time))
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
eval_dataset, size = load_and_cache_examples(
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
preds = None
num_eval_steps = math.ceil(size / eval_batch_size)
master = master_bar(range(1))
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
loss = 0.0
logging.info("***** Running evaluation *****")
logging.info(" Num examples = %d", size)
logging.info(" Batch size = %d", eval_batch_size)
for eval_features, eval_labels in eval_iterator:
inputs = {"attention_mask": eval_features["attention_mask"], "training": False}
if "token_type_ids" in eval_features:
inputs["token_type_ids"] = eval_features["token_type_ids"]
with strategy.scope():
logits = model(eval_features["input_ids"], **inputs)[0]
active_loss = tf.reshape(eval_labels, (-1,)) != pad_token_label_id
active_logits = tf.boolean_mask(tf.reshape(logits, (-1, len(labels))), active_loss)
active_labels = tf.boolean_mask(tf.reshape(eval_labels, (-1,)), active_loss)
cross_entropy = loss_fct(active_labels, active_logits)
loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)
if preds is None:
preds = logits.numpy()
label_ids = eval_labels.numpy()
else:
preds = np.append(preds, logits.numpy(), axis=0)
label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)
preds = np.argmax(preds, axis=2)
y_pred = [[] for _ in range(label_ids.shape[0])]
y_true = [[] for _ in range(label_ids.shape[0])]
loss = loss / num_eval_steps
for i in range(label_ids.shape[0]):
for j in range(label_ids.shape[1]):
if label_ids[i, j] != pad_token_label_id:
y_pred[i].append(labels[preds[i, j] - 1])
y_true[i].append(labels[label_ids[i, j] - 1])
return y_true, y_pred, loss.numpy()
def load_cache(cached_file, tokenizer: PreTrainedTokenizer, max_seq_length):
name_to_features = {
"input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"attention_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
}
# TODO Find a cleaner way to do this.
if "token_type_ids" in tokenizer.model_input_names:
name_to_features["token_type_ids"] = tf.io.FixedLenFeature([max_seq_length], tf.int64)
def _decode_record(record):
example = tf.io.parse_single_example(record, name_to_features)
features = {}
features["input_ids"] = example["input_ids"]
features["attention_mask"] = example["attention_mask"]
if "token_type_ids" in example:
features["token_type_ids"] = example["token_type_ids"]
return features, example["label_ids"]
d = tf.data.TFRecordDataset(cached_file)
d = d.map(_decode_record, num_parallel_calls=4)
count = d.reduce(0, lambda x, _: x + 1)
return d, count.numpy()
def save_cache(features, cached_features_file):
writer = tf.io.TFRecordWriter(cached_features_file)
for (ex_index, feature) in enumerate(features):
if ex_index % 5000 == 0:
logging.info("Writing example %d of %d" % (ex_index, len(features)))
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
record_feature = collections.OrderedDict()
record_feature["input_ids"] = create_int_feature(feature.input_ids)
record_feature["attention_mask"] = create_int_feature(feature.attention_mask)
if feature.token_type_ids is not None:
record_feature["token_type_ids"] = create_int_feature(feature.token_type_ids)
record_feature["label_ids"] = create_int_feature(feature.label_ids)
tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))
writer.write(tf_example.SerializeToString())
writer.close()
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
drop_remainder = True if args["tpu"] or mode == "train" else False
# Load data features from cache or dataset file
cached_features_file = os.path.join(
args["data_dir"],
"cached_{}_{}_{}.tf_record".format(mode, tokenizer.__class__.__name__, str(args["max_seq_length"])),
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
# If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
# or just modify its tokenizer_config.json.
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
logging.info("Loading features from cached file %s", cached_features_file)
dataset, size = load_cache(cached_features_file, tokenizer, args["max_seq_length"])
else:
logging.info("Creating features from dataset file at %s", args["data_dir"])
examples = read_examples_from_file(args["data_dir"], mode)
features = convert_examples_to_features(
examples,
labels,
args["max_seq_length"],
tokenizer,
cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(args["model_type"] in ["roberta"]),
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args["model_type"] in ["xlnet"]),
# pad on the left for xlnet
pad_token=tokenizer.pad_token_id,
pad_token_segment_id=tokenizer.pad_token_type_id,
pad_token_label_id=pad_token_label_id,
)
logging.info("Saving features into cached file %s", cached_features_file)
save_cache(features, cached_features_file)
dataset, size = load_cache(cached_features_file, tokenizer, args["max_seq_length"])
if mode == "train":
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])
dataset = dataset.batch(batch_size, drop_remainder)
dataset = dataset.prefetch(buffer_size=batch_size)
return dataset, size
def main(_):
logging.set_verbosity(logging.INFO)
args = flags.FLAGS.flag_values_dict()
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
data_dir: str = field(
metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."}
)
labels: Optional[str] = field(
metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."}
)
max_seq_length: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(args["output_dir"])
and os.listdir(args["output_dir"])
and args["do_train"]
and not args["overwrite_output_dir"]
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args["output_dir"]
)
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
if args["fp16"]:
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
if args["tpu"]:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
args["n_device"] = args["num_tpu_cores"]
elif len(args["gpus"].split(",")) > 1:
args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
elif args["no_cuda"]:
args["n_device"] = 1
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
else:
args["n_device"] = len(args["gpus"].split(","))
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])
logging.warning(
"n_device: %s, distributed training: %s, 16-bits training: %s",
args["n_device"],
bool(args["n_device"] > 1),
args["fp16"],
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
labels = get_labels(args["labels"])
# Prepare Token Classification task
labels = get_labels(data_args.labels)
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)
pad_token_label_id = -1
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
args["config_name"] if args["config_name"] else args["model_name_or_path"],
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=args["cache_dir"],
id2label=label_map,
label2id={label: i for i, label in enumerate(labels)},
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast,
)
logging.info("Training/evaluation parameters %s", args)
args["model_type"] = config.model_type
with training_args.strategy.scope():
model = TFAutoModelForTokenClassification.from_pretrained(
model_args.model_name_or_path,
from_pt=bool(".bin" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# Get datasets
train_dataset = (
TFNerDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
model_type=config.model_type,
max_seq_length=data_args.max_seq_length,
overwrite_cache=data_args.overwrite_cache,
mode=Split.train,
)
if training_args.do_train
else None
)
eval_dataset = (
TFNerDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
model_type=config.model_type,
max_seq_length=data_args.max_seq_length,
overwrite_cache=data_args.overwrite_cache,
mode=Split.dev,
)
if training_args.do_eval
else None
)
def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
preds = np.argmax(predictions, axis=2)
batch_size, seq_len = preds.shape
out_label_list = [[] for _ in range(batch_size)]
preds_list = [[] for _ in range(batch_size)]
for i in range(batch_size):
for j in range(seq_len):
if label_ids[i, j] != -1:
out_label_list[i].append(label_map[label_ids[i][j]])
preds_list[i].append(label_map[preds[i][j]])
return preds_list, out_label_list
def compute_metrics(p: EvalPrediction) -> Dict:
preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
return {
"precision": precision_score(out_label_list, preds_list),
"recall": recall_score(out_label_list, preds_list),
"f1": f1_score(out_label_list, preds_list),
}
# Initialize our Trainer
trainer = TFTrainer(
model=model,
args=training_args,
train_dataset=train_dataset.get_dataset() if train_dataset else None,
eval_dataset=eval_dataset.get_dataset() if eval_dataset else None,
compute_metrics=compute_metrics,
)
# Training
if args["do_train"]:
tokenizer = AutoTokenizer.from_pretrained(
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
do_lower_case=args["do_lower_case"],
cache_dir=args["cache_dir"],
)
with strategy.scope():
model = TFAutoModelForTokenClassification.from_pretrained(
args["model_name_or_path"],
from_pt=bool(".bin" in args["model_name_or_path"]),
config=config,
cache_dir=args["cache_dir"],
)
train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
train_dataset, num_train_examples = load_and_cache_examples(
args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
)
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
train(
args,
strategy,
train_dataset,
tokenizer,
model,
num_train_examples,
labels,
train_batch_size,
pad_token_label_id,
)
os.makedirs(args["output_dir"], exist_ok=True)
logging.info("Saving model to %s", args["output_dir"])
model.save_pretrained(args["output_dir"])
tokenizer.save_pretrained(args["output_dir"])
if training_args.do_train:
trainer.train()
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)
# Evaluation
if args["do_eval"]:
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
checkpoints = []
results = []
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
if args["eval_all_checkpoints"]:
checkpoints = list(
os.path.dirname(c)
for c in sorted(
glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
)
)
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
logging.info("Evaluate the following checkpoints: %s", checkpoints)
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
if len(checkpoints) == 0:
checkpoints.append(args["output_dir"])
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
results.update(result)
with strategy.scope():
model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)
y_true, y_pred, eval_loss = evaluate(
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
)
report = metrics.classification_report(y_true, y_pred, digits=4)
if global_step:
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
for res in results:
for key, val in res.items():
if "loss" in key:
logging.info(key + " = " + str(val))
writer.write(key + " = " + str(val))
writer.write("\n")
else:
logging.info(key)
logging.info("\n" + report)
writer.write(key + "\n")
writer.write(report)
writer.write("\n")
if args["do_predict"]:
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
predict_dataset, _ = load_and_cache_examples(
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
# Predict
if training_args.do_predict:
test_dataset = TFNerDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
model_type=config.model_type,
max_seq_length=data_args.max_seq_length,
overwrite_cache=data_args.overwrite_cache,
mode=Split.test,
)
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
report = metrics.classification_report(y_true, y_pred, digits=4)
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
report = metrics.classification_report(y_true, y_pred, digits=4)
predictions, label_ids, metrics = trainer.predict(test_dataset.get_dataset())
preds_list, labels_list = align_predictions(predictions, label_ids)
report = classification_report(labels_list, preds_list)
logging.info("\n" + report)
logger.info("\n%s", report)
writer.write(report)
writer.write("\n\nloss = " + str(pred_loss))
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
with open(output_test_results_file, "w") as writer:
writer.write("%s\n" % report)
# Save predictions
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not y_pred[example_id]:
if not preds_list[example_id]:
example_id += 1
elif y_pred[example_id]:
output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
elif preds_list[example_id]:
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
return results
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("output_dir")
flags.mark_flag_as_required("model_name_or_path")
app.run(main)
main()

View File

@ -22,11 +22,7 @@ from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from transformers import PreTrainedTokenizer, torch_distributed_zero_first
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
logger = logging.getLogger(__name__)
@ -68,70 +64,170 @@ class Split(Enum):
test = "test"
class NerDataset(Dataset):
"""
This will be superseded by a framework-agnostic approach
soon.
"""
if is_torch_available():
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from transformers import torch_distributed_zero_first
features: List[InputFeatures]
pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
# Use cross entropy ignore_index as padding label id so that only
# real label ids contribute to the loss later.
class NerDataset(Dataset):
"""
This will be superseded by a framework-agnostic approach
soon.
"""
def __init__(
self,
data_dir: str,
tokenizer: PreTrainedTokenizer,
labels: List[str],
model_type: str,
max_seq_length: Optional[int] = None,
overwrite_cache=False,
mode: Split = Split.train,
local_rank=-1,
):
# Load data features from cache or dataset file
cached_features_file = os.path.join(
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
)
features: List[InputFeatures]
pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
# Use cross entropy ignore_index as padding label id so that only
# real label ids contribute to the loss later.
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
def __init__(
self,
data_dir: str,
tokenizer: PreTrainedTokenizer,
labels: List[str],
model_type: str,
max_seq_length: Optional[int] = None,
overwrite_cache=False,
mode: Split = Split.train,
local_rank=-1,
):
# Load data features from cache or dataset file
cached_features_file = os.path.join(
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
)
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)
else:
logger.info(f"Creating features from dataset file at {data_dir}")
examples = read_examples_from_file(data_dir, mode)
# TODO clean up all this to leverage built-in features of tokenizers
self.features = convert_examples_to_features(
examples,
labels,
max_seq_length,
tokenizer,
cls_token_at_end=bool(model_type in ["xlnet"]),
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(model_type in ["roberta"]),
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(tokenizer.padding_side == "left"),
pad_token=tokenizer.pad_token_id,
pad_token_segment_id=tokenizer.pad_token_type_id,
pad_token_label_id=self.pad_token_label_id,
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)
else:
logger.info(f"Creating features from dataset file at {data_dir}")
examples = read_examples_from_file(data_dir, mode)
# TODO clean up all this to leverage built-in features of tokenizers
self.features = convert_examples_to_features(
examples,
labels,
max_seq_length,
tokenizer,
cls_token_at_end=bool(model_type in ["xlnet"]),
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(model_type in ["roberta"]),
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(tokenizer.padding_side == "left"),
pad_token=tokenizer.pad_token_id,
pad_token_segment_id=tokenizer.pad_token_type_id,
pad_token_label_id=self.pad_token_label_id,
)
if local_rank in [-1, 0]:
logger.info(f"Saving features into cached file {cached_features_file}")
torch.save(self.features, cached_features_file)
def __len__(self):
return len(self.features)
def __getitem__(self, i) -> InputFeatures:
return self.features[i]
if is_tf_available():
import tensorflow as tf
class TFNerDataset:
"""
This will be superseded by a framework-agnostic approach
soon.
"""
features: List[InputFeatures]
pad_token_label_id: int = -1
# Use cross entropy ignore_index as padding label id so that only
# real label ids contribute to the loss later.
def __init__(
self,
data_dir: str,
tokenizer: PreTrainedTokenizer,
labels: List[str],
model_type: str,
max_seq_length: Optional[int] = None,
overwrite_cache=False,
mode: Split = Split.train,
):
examples = read_examples_from_file(data_dir, mode)
# TODO clean up all this to leverage built-in features of tokenizers
self.features = convert_examples_to_features(
examples,
labels,
max_seq_length,
tokenizer,
cls_token_at_end=bool(model_type in ["xlnet"]),
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(model_type in ["roberta"]),
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(tokenizer.padding_side == "left"),
pad_token=tokenizer.pad_token_id,
pad_token_segment_id=tokenizer.pad_token_type_id,
pad_token_label_id=self.pad_token_label_id,
)
def gen():
for ex in self.features:
if ex.token_type_ids is None:
yield (
{"input_ids": ex.input_ids, "attention_mask": ex.attention_mask},
ex.label_ids,
)
else:
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
},
ex.label_ids,
)
if "token_type_ids" not in tokenizer.model_input_names:
self.dataset = tf.data.Dataset.from_generator(
gen,
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
(
{"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])},
tf.TensorShape([None]),
),
)
else:
self.dataset = tf.data.Dataset.from_generator(
gen,
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
(
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
},
tf.TensorShape([None]),
),
)
if local_rank in [-1, 0]:
logger.info(f"Saving features into cached file {cached_features_file}")
torch.save(self.features, cached_features_file)
def __len__(self):
return len(self.features)
def get_dataset(self):
return self.dataset
def __getitem__(self, i) -> InputFeatures:
return self.features[i]
def __len__(self):
return len(self.features)
def __getitem__(self, i) -> InputFeatures:
return self.features[i]
def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]:

View File

@ -1,105 +1,229 @@
import os
# coding=utf-8
""" Fine-tuning the library models for sequence classification."""
import tensorflow as tf
import tensorflow_datasets
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
import numpy as np
import tensorflow_datasets as tfds
from transformers import (
BertConfig,
BertForSequenceClassification,
BertTokenizer,
TFBertForSequenceClassification,
AutoConfig,
AutoTokenizer,
EvalPrediction,
HfArgumentParser,
PreTrainedTokenizer,
TFAutoModelForSequenceClassification,
TFTrainer,
TFTrainingArguments,
glue_compute_metrics,
glue_convert_examples_to_features,
glue_output_modes,
glue_processors,
glue_tasks_num_labels,
)
# script parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2
USE_XLA = False
USE_AMP = False
EPOCHS = 3
TASK = "mrpc"
if TASK == "sst-2":
TFDS_TASK = "sst2"
elif TASK == "sts-b":
TFDS_TASK = "stsb"
else:
TFDS_TASK = TASK
num_labels = len(glue_processors[TASK]().get_labels())
print(num_labels)
tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased", config=config)
# Load dataset via TensorFlow Datasets
data, info = tensorflow_datasets.load(f"glue/{TFDS_TASK}", with_info=True)
train_examples = info.splits["train"].num_examples
# MNLI expects either validation_matched or validation_mismatched
valid_examples = info.splits["validation"].num_examples
# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, max_length=128, task=TASK)
# MNLI expects either validation_matched or validation_mismatched
valid_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, max_length=128, task=TASK)
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
if USE_AMP:
# loss scaling is currently required when using mixed precision
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
class Split(Enum):
train = "train"
dev = "validation"
test = "test"
if num_labels == 1:
loss = tf.keras.losses.MeanSquaredError()
else:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
def get_tfds(
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
):
if task_name == "mnli-mm" and mode == Split.dev:
tfds_name = "mnli_mismatched"
elif task_name == "mnli-mm" and mode == Split.train:
tfds_name = "mnli"
elif task_name == "mnli" and mode == Split.dev:
tfds_name = "mnli_matched"
elif task_name == "sst-2":
tfds_name = "sst2"
elif task_name == "sts-b":
tfds_name = "stsb"
else:
tfds_name = task_name
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model.compile(optimizer=opt, loss=loss, metrics=[metric])
ds = tfds.load("glue/" + tfds_name, split=mode.value)
# Train and evaluate using tf.keras.Model.fit()
train_steps = train_examples // BATCH_SIZE
valid_steps = valid_examples // EVAL_BATCH_SIZE
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
history = model.fit(
train_dataset,
epochs=EPOCHS,
steps_per_epoch=train_steps,
validation_data=valid_dataset,
validation_steps=valid_steps,
)
# Save TF2 model
os.makedirs("./save/", exist_ok=True)
model.save_pretrained("./save/")
logger = logging.getLogger(__name__)
if TASK == "mrpc":
# Load the TensorFlow model in PyTorch for inspection
# This is to demo the interoperability between the two frameworks, you don't have to
# do this in real life (you can run the inference on the TF model).
pytorch_model = BertForSequenceClassification.from_pretrained("./save/", from_tf=True)
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
sentence_0 = "This research was consistent with his findings."
sentence_1 = "His findings were compatible with this research."
sentence_2 = "His findings were not compatible with this research."
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors="pt")
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors="pt")
@dataclass
class GlueDataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")
print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
max_seq_length: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
def __post_init__(self):
self.task_name = self.task_name.lower()
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
# If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
# or just modify its tokenizer_config.json.
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, GlueDataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
try:
num_labels = glue_tasks_num_labels["mnli" if data_args.task_name == "mnli-mm" else data_args.task_name]
output_mode = glue_output_modes[data_args.task_name]
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name))
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
with training_args.strategy.scope():
model = TFAutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
from_pt=bool(".bin" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# Get datasets
train_dataset = (
get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
if training_args.do_train
else None
)
eval_dataset = (
get_tfds(
task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
)
if training_args.do_eval
else None
)
def compute_metrics(p: EvalPrediction) -> Dict:
if output_mode == "classification":
preds = np.argmax(p.predictions, axis=1)
elif output_mode == "regression":
preds = np.squeeze(p.predictions)
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
# Initialize our Trainer
trainer = TFTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
)
# Training
if training_args.do_train:
trainer.train()
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
return results
if __name__ == "__main__":
main()

View File

@ -145,7 +145,9 @@ from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .trainer_utils import EvalPrediction
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -502,6 +504,9 @@ if is_tf_available():
# Optimization
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
# Trainer
from .trainer_tf import TFTrainer
if not is_tf_available() and not is_torch_available():
logger.warning(

View File

@ -21,9 +21,11 @@ import tensorflow as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
"""Applies a warmup schedule on a given learning rate decay schedule."""
def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None):
def __init__(
self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None,
):
super().__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
@ -56,34 +58,34 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
}
def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
def create_optimizer(init_lr, num_train_steps, num_warmup_steps, end_lr=0.0, optimizer_type="adamw"):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=0.0
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=end_lr,
)
if num_warmup_steps:
learning_rate_fn = WarmUp(
initial_learning_rate=init_lr, decay_schedule_fn=learning_rate_fn, warmup_steps=num_warmup_steps
lr_schedule = WarmUp(
initial_learning_rate=init_lr, decay_schedule_fn=lr_schedule, warmup_steps=num_warmup_steps,
)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
learning_rate=lr_schedule,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=["layer_norm", "bias"],
)
return optimizer
class AdamWeightDecay(tf.keras.optimizers.Adam):
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
Just adding the square of the weights to the loss function is *not* the
correct way of using L2 regularization/weight decay with Adam, since that will
interact with the m and v parameters in strange ways.
Instead we want ot decay the weights in a manner that doesn't interact with
the m/v parameters. This is equivalent to adding the square of the weights to
the loss with plain (non-momentum) SGD.
@ -111,24 +113,26 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {"WarmUp": WarmUp}
return super().from_config(config, custom_objects=custom_objects)
return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state):
super()._prepare_local(var_device, var_dtype, apply_state)
apply_state["weight_decay_rate"] = tf.constant(self.weight_decay_rate, name="adam_weight_decay_rate")
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
self.weight_decay_rate, name="adam_weight_decay_rate"
)
def _decay_weights_op(self, var, learning_rate, apply_state):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var * apply_state["weight_decay_rate"], use_locking=self._use_locking
learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
use_locking=self._use_locking,
)
return tf.no_op()
def apply_gradients(self, grads_and_vars, clip_norm, name=None):
def apply_gradients(self, grads_and_vars, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
return super().apply_gradients(zip(grads, tvars))
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name,)
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
@ -147,13 +151,13 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super()._resource_apply_dense(grad, var, **kwargs)
return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super()._resource_apply_sparse(grad, var, indices, **kwargs)
return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)
def get_config(self):
config = super().get_config()
@ -177,71 +181,65 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
return True
# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
"""Distribution strategies-aware gradient accumulation utility."""
"""Gradient accumulation utility.
When used with a distribution strategy, the accumulator should be called in a
replica context. Gradients will be accumulated locally on each replica and
without synchronization. Users should then call ``.gradients``, scale the
gradients if required, and pass the result to ``apply_gradients``.
"""
# We use the ON_READ synchronization policy so that no synchronization is
# performed on assignment. To get the value, we call .value() which returns the
# value on the current replica without synchronization.
def __init__(self):
"""Initializes the accumulator."""
self._gradients = []
self._accum_steps = tf.Variable(
initial_value=0, dtype=tf.int64, trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
)
self._accum_steps = None
@property
def step(self):
"""Number of accumulated steps."""
if self._accum_steps is None:
self._accum_steps = tf.Variable(
tf.constant(0, dtype=tf.int64), trainable=False, synchronization=tf.VariableSynchronization.ON_READ,
)
return self._accum_steps.value()
@property
def gradients(self):
"""The accumulated gradients."""
return list(
gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients()
)
"""The accumulated gradients on the current replica."""
if not self._gradients:
raise ValueError("The accumulator should be called first to initialize the gradients")
return list(gradient.value() for gradient in self._gradients)
def __call__(self, gradients):
"""Accumulates :obj:`gradients`."""
"""Accumulates :obj:`gradients` on the current replica."""
if not self._gradients:
_ = self.step # Create the step variable.
self._gradients.extend(
[
tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient
tf.Variable(
tf.zeros_like(gradient), trainable=False, synchronization=tf.VariableSynchronization.ON_READ,
)
for gradient in gradients
]
)
if len(gradients) != len(self._gradients):
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
for accum_gradient, gradient in zip(self._get_replica_gradients(), gradients):
if accum_gradient is not None and gradient is not None:
accum_gradient.assign_add(gradient)
for accum_gradient, gradient in zip(self._gradients, gradients):
accum_gradient.assign_add(gradient)
self._accum_steps.assign_add(1)
def reset(self):
"""Resets the accumulated gradients."""
if self._gradients:
self._accum_steps.assign(0)
for gradient in self._get_replica_gradients():
if gradient is not None:
gradient.assign(tf.zeros_like(gradient))
def _get_replica_gradients(self):
if tf.distribute.has_strategy():
# In a replica context, we want to accumulate gradients on each replica
# without synchronization, so we directly assign the value of the
# current replica.
replica_context = tf.distribute.get_replica_context()
if replica_context is None or tf.distribute.get_strategy().num_replicas_in_sync == 1:
return self._gradients
return (
gradient.device_map.select_for_current_replica(gradient.values, replica_context)
for gradient in self._gradients
if gradient is not None
)
else:
return self._gradients
"""Resets the accumulated gradients on the current replica."""
if not self._gradients:
return
self._accum_steps.assign(0)
for gradient in self._gradients:
gradient.assign(tf.zeros_like(gradient))

View File

@ -6,7 +6,7 @@ import re
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
@ -20,6 +20,7 @@ from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DefaultDataCollator
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
from .training_args import TrainingArguments
@ -87,30 +88,6 @@ def torch_distributed_zero_first(local_rank: int):
torch.distributed.barrier()
class EvalPrediction(NamedTuple):
"""
Evaluation output (always contains labels), to be used
to compute metrics.
"""
predictions: np.ndarray
label_ids: np.ndarray
class PredictionOutput(NamedTuple):
predictions: np.ndarray
label_ids: Optional[np.ndarray]
metrics: Optional[Dict[str, float]]
class TrainOutput(NamedTuple):
global_step: int
training_loss: float
PREFIX_CHECKPOINT_DIR = "checkpoint"
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch,

View File

@ -0,0 +1,429 @@
"""Tensorflow trainer class."""
import logging
import math
import os
from typing import Callable, Dict, Optional
import numpy as np
import tensorflow as tf
from .modeling_tf_utils import TFPreTrainedModel, shape_list
from .optimization_tf import GradientAccumulator, create_optimizer
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput
from .training_args_tf import TFTrainingArguments
logger = logging.getLogger(__name__)
class TFTrainer:
model: TFPreTrainedModel
args: TFTrainingArguments
# something similar to a PT Dataset.
# This is just temporary before to have
# a framework-agnostic approach for datasets.
train_dataset: Optional[tf.data.Dataset]
eval_dataset: Optional[tf.data.Dataset]
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
prediction_loss_only: bool
def __init__(
self,
model: TFPreTrainedModel,
args: TFTrainingArguments,
train_dataset: Optional[tf.data.Dataset] = None,
eval_dataset: Optional[tf.data.Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
):
self.model = model
self.args = args
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.gradient_accumulator = GradientAccumulator()
self._setup_training()
def _setup_training(self) -> None:
"""
Setup the different steps to train a model:
- check if all the data are given
- create the proper strategy
- create the features
- prepare the model settings
"""
self._prepare_dataset()
with self.args.strategy.scope():
self._create_optimizer()
_ = self.optimizer.iterations
self._set_loss_and_metric()
self._create_checkpoint_manager()
self._create_summary_writer()
def _set_loss_and_metric(self) -> None:
"""
Create the training loss and metric with their name. Allowed names are those listed
in the Tensorflow documentation and those contained in the transformers library.
"""
try:
self.loss = tf.keras.losses.get(
{
"class_name": self.args.loss_name,
"config": {"from_logits": True, "reduction": tf.keras.losses.Reduction.NONE},
}
)
except TypeError:
self.loss = tf.keras.losses.get(
{"class_name": self.args.loss_name, "config": {"reduction": tf.keras.losses.Reduction.NONE}}
)
def _create_summary_writer(self) -> None:
"""
Create a summary writer to be able to read the logs in Tensorboard.
"""
self.writer = tf.summary.create_file_writer(self.args.logging_dir)
def _prepare_dataset(self) -> None:
"""
Prepare the training, validation and test data.
"""
if self.train_dataset is not None:
self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
if self.args.max_steps > 0:
self.train_steps = self.args.max_steps
else:
self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
self.train_dataset = (
self.train_dataset.cache()
.shuffle(self.num_train_examples)
.batch(self.args.train_batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
if self.args.max_steps > 0:
self.train_dataset = self.train_dataset.repeat(-1)
self.train_dataset = self.args.strategy.experimental_distribute_dataset(self.train_dataset)
else:
self.train_steps = 0
if self.eval_dataset is not None:
self.eval_dataset = (
self.eval_dataset.batch(self.args.eval_batch_size).cache().prefetch(tf.data.experimental.AUTOTUNE)
)
self.eval_dataset = self.args.strategy.experimental_distribute_dataset(self.eval_dataset)
def _create_optimizer(self) -> None:
"""
Create the training optimizer with its name. Allowed names are those listed
in the Tensorflow documentation and those contained in the transformers library.
"""
if self.args.optimizer_name == "adamw":
self.optimizer = create_optimizer(self.args.learning_rate, self.train_steps, self.args.warmup_steps)
else:
try:
self.optimizer = tf.keras.optimizers.get(
{
"class_name": self.args.optimizer_name,
"config": {"learning_rate": self.args.learning_rate, "epsilon": self.args.adam_epsilon},
}
)
except TypeError:
# This is for the case where the optimizer is not Adam-like such as SGD
self.optimizer = tf.keras.optimizers.get(
{"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
)
def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
"""
Create a checkpoint manager in order to be able to make the training
fault-tolerant.
Args:
max_to_keep: the maximum number of checkpoints to keep in the checkpoint path.
load_model: if we want to start the training from the latest checkpoint.
"""
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=max_to_keep)
if load_model:
ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
@tf.function
def _evaluate_steps(self, per_replica_features, per_replica_labels):
"""
One step evaluation across replica.
Args:
per_replica_features: the batched features.
per_replica_labels: the batched labels.
Returns:
The loss corresponding to the given batch.
"""
per_replica_loss, per_replica_logits = self.args.strategy.experimental_run_v2(
self._run_model, args=(per_replica_features, per_replica_labels, False)
)
try:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
except ValueError:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
return reduced_loss, per_replica_logits
def _prediction_loop(
self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
) -> PredictionOutput:
logger.info("***** Running %s *****", description)
logger.info(" Batch size = %d", self.args.eval_batch_size)
label_ids: np.ndarray = None
preds: np.ndarray = None
step: int = 1
for features, labels in dataset:
step = tf.convert_to_tensor(step, dtype=tf.int64)
loss, logits = self._evaluate_steps(features, labels)
loss = tf.reduce_mean(loss)
if not prediction_loss_only:
if self.args.n_gpu > 1:
for val in logits.values:
if preds is None:
preds = val.numpy()
else:
preds = np.append(preds, val.numpy(), axis=0)
for val in labels.values:
if label_ids is None:
label_ids = val.numpy()
else:
label_ids = np.append(label_ids, val.numpy(), axis=0)
else:
if preds is None:
preds = logits.numpy()
else:
preds = np.append(preds, logits.numpy(), axis=0)
if label_ids is None:
label_ids = labels.numpy()
else:
label_ids = np.append(label_ids, labels.numpy(), axis=0)
step += 1
if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else:
metrics = {}
metrics["loss"] = loss.numpy()
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def evaluate(
self, eval_dataset: Optional[tf.data.Dataset] = None, prediction_loss_only: Optional[bool] = None
) -> Dict[str, float]:
"""
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
"""
if eval_dataset is None:
eval_dataset = self.eval_dataset
output = self._prediction_loop(eval_dataset, description="Evaluation")
return output.metrics
def train(self) -> None:
"""
Train method to train the model.
"""
if self.args.debug:
tf.summary.trace_on(graph=True, profiler=True)
self.gradient_accumulator.reset()
iterations = self.optimizer.iterations
if iterations.numpy() > 0:
logger.info("Start the training from the last checkpoint")
start_epoch = (iterations.numpy() // self.train_steps) + 1
else:
start_epoch = 1
tf.summary.experimental.set_step(iterations)
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_train_examples)
logger.info(" Num Epochs = %d", epochs)
logger.info(" Total optimization steps = %d", self.train_steps)
for epoch in range(start_epoch, int(epochs + 1)):
for training_loss in self._training_steps():
step = iterations.numpy()
if self.args.debug:
with self.writer.as_default():
tf.summary.scalar("loss", training_loss, step=step)
if step == 1 and self.args.debug:
with self.writer.as_default():
tf.summary.trace_export(name="training", step=step, profiler_outdir=self.args.logging_dir)
if self.args.evaluate_during_training and step % self.args.eval_steps == 0:
logs = {}
results = self.evaluate()
for key, value in results.items():
eval_key = "eval_{}".format(key)
logs[eval_key] = value
if callable(self.optimizer.learning_rate):
logs["learning_rate"] = self.optimizer.learning_rate(step).numpy()
else:
logs["learning_rate"] = self.optimizer.learning_rate.numpy()
logger.info("Epoch {} Step {} Validation Metrics {}".format(epoch, step, logs))
with self.writer.as_default():
for k, v in logs.items():
tf.summary.scalar(k, v, step=step)
if step % self.args.logging_steps == 0:
logger.info("Epoch {} Step {} Train Loss {:.4f}".format(epoch, step, training_loss.numpy()))
if step % self.args.save_steps == 0:
ckpt_save_path = self.model.ckpt_manager.save()
logger.info("Saving checkpoint for step {} at {}".format(step, ckpt_save_path))
if step % self.train_steps == 0:
break
def _training_steps(self):
"""
Returns a generator over training steps (i.e. parameters update).
"""
for i, loss in enumerate(self._accumulate_next_gradients()):
if i % self.args.gradient_accumulation_steps == 0:
self._apply_gradients()
yield loss
@tf.function
def _apply_gradients(self):
"""Applies the gradients (cross-replica)."""
self.args.strategy.experimental_run_v2(self._step)
def _step(self):
"""Applies gradients and resets accumulation."""
gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
gradients = [
gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
]
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
vars = self.model.trainable_variables
if self.args.mode == "token-classification":
vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
self.optimizer.apply_gradients(list(zip(gradients, vars)))
self.gradient_accumulator.reset()
def _accumulate_next_gradients(self):
"""Accumulates the gradients from the next element in dataset."""
iterator = iter(self.train_dataset)
@tf.function
def _accumulate_next():
per_replica_features, per_replica_labels = next(iterator)
return self._accumulate_gradients(per_replica_features, per_replica_labels)
while True:
try:
yield _accumulate_next()
except tf.errors.OutOfRangeError:
break
def _accumulate_gradients(self, per_replica_features, per_replica_labels):
"""Accumulates the gradients across all the replica."""
per_replica_loss = self.args.strategy.experimental_run_v2(
self._forward, args=(per_replica_features, per_replica_labels)
)
try:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
except ValueError:
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
return reduced_loss
def _forward(self, features, labels):
"""Forwards a training example and accumulates the gradients."""
per_example_loss, _ = self._run_model(features, labels, True)
vars = self.model.trainable_variables
if self.args.mode == "token-classification":
vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
gradients = self.optimizer.get_gradients(per_example_loss, vars)
self.gradient_accumulator(gradients)
return per_example_loss
def _run_model(self, features, labels, training):
"""
Computes the loss of the given features and labels pair.
Args:
features: the batched features.
labels: the batched labels.
training: run the model in training mode or not
"""
if self.args.mode == "sequence-classification" or self.args.mode == "token-classification":
logits = self.model(features, training=training)[0]
else:
logits = self.model(features, training=training)
if self.args.mode == "token-classification":
active_loss = tf.reshape(labels, (-1,)) != -1
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
loss = self.loss(labels, reduced_logits)
else:
loss = self.loss(labels, logits)
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
return loss, logits
def predict(self, test_dataset: tf.data.Dataset) -> PredictionOutput:
"""
Run prediction and return predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels.
In that case, this method will also return metrics, like in evaluate().
Args:
test_dataset: something similar to a PT Dataset. This is just
temporary before to have a framework-agnostic approach for datasets.
"""
test_dataset = test_dataset.batch(self.args.eval_batch_size)
test_dataset = self.args.strategy.experimental_distribute_dataset(test_dataset)
return self._prediction_loop(test_dataset, description="Prediction")
def save_model(self) -> None:
"""
Save the pretrained model and create a Tensorflow saved model.
"""
logger.info("Saving model in {}".format(self.args.output_dir))
path = os.path.join(self.args.output_dir, "saved_model")
os.makedirs(path, exist_ok=True)
self.model.save_pretrained(self.args.output_dir)

View File

@ -0,0 +1,27 @@
from typing import Dict, NamedTuple, Optional
import numpy as np
class EvalPrediction(NamedTuple):
"""
Evaluation output (always contains labels), to be used
to compute metrics.
"""
predictions: np.ndarray
label_ids: np.ndarray
class PredictionOutput(NamedTuple):
predictions: np.ndarray
label_ids: Optional[np.ndarray]
metrics: Optional[Dict[str, float]]
class TrainOutput(NamedTuple):
global_step: int
training_loss: float
PREFIX_CHECKPOINT_DIR = "checkpoint"

View File

@ -0,0 +1,75 @@
import logging
from dataclasses import dataclass, field
from typing import Tuple
from .file_utils import cached_property, is_tf_available, tf_required
from .training_args import TrainingArguments
logger = logging.getLogger(__name__)
if is_tf_available():
import tensorflow as tf
@dataclass
class TFTrainingArguments(TrainingArguments):
optimizer_name: str = field(
default="adam",
metadata={
"help": 'Name of a Tensorflow optimizer among "adadelta, adagrad, adam, adamax, ftrl, nadam, rmsprop, sgd, adamw"'
},
)
mode: str = field(
default="sequence-classification",
metadata={"help": 'Type of task, one of "sequence-classification", "token-classification" '},
)
loss_name: str = field(
default="SparseCategoricalCrossentropy",
metadata={
"help": "Name of a Tensorflow loss. For the list see: https://www.tensorflow.org/api_docs/python/tf/keras/losses"
},
)
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
debug: bool = field(
default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
)
@cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
logger.info("Tensorflow: setting up strategy")
gpus = tf.config.list_physical_devices("GPU")
if self.no_cuda:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
else:
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
tpu = None
if tpu:
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
elif len(gpus) == 0:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
elif len(gpus) > 1:
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
strategy = tf.distribute.MirroredStrategy(gpus)
else:
raise ValueError("Cannot find the proper strategy please check your environment properties.")
return strategy
@property
@tf_required
def strategy(self) -> "tf.distribute.Strategy":
return self._setup_strategy
@property
@tf_required
def n_gpu(self) -> int:
return self._setup_strategy.num_replicas_in_sync

View File

@ -36,14 +36,13 @@ class OptimizationFTest(unittest.TestCase):
def testGradientAccumulatorDistributionStrategy(self):
context._context = None
ops.enable_eager_execution_internal()
physical_devices = tf.config.experimental.list_physical_devices("CPU")
tf.config.experimental.set_virtual_device_configuration(
physical_devices[0],
[tf.config.experimental.VirtualDeviceConfiguration(), tf.config.experimental.VirtualDeviceConfiguration()],
)
devices = tf.config.experimental.list_logical_devices(device_type="CPU")
strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])
physical_devices = tf.config.list_physical_devices("CPU")
if len(physical_devices) == 1:
tf.config.set_logical_device_configuration(
physical_devices[0], [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
)
devices = tf.config.list_logical_devices(device_type="CPU")
strategy = tf.distribute.MirroredStrategy(devices=devices[:2])
with strategy.scope():
accumulator = GradientAccumulator()
@ -55,13 +54,14 @@ class OptimizationFTest(unittest.TestCase):
accumulator([gradient])
def apply_on_replica():
optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)
optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])))
@tf.function
def accumulate(grad1, grad2):
with strategy.scope():
gradient_placeholder.values[0].assign(grad1)
gradient_placeholder.values[1].assign(grad2)
local_variables = strategy.experimental_local_results(gradient_placeholder)
local_variables[0].assign(grad1)
local_variables[1].assign(grad2)
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
@tf.function
@ -69,15 +69,18 @@ class OptimizationFTest(unittest.TestCase):
with strategy.scope():
strategy.experimental_run_v2(apply_on_replica)
def _check_local_values(grad1, grad2):
values = strategy.experimental_local_results(accumulator._gradients[0])
self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)
accumulate([1.0, 2.0], [-1.0, 1.0])
accumulate([3.0, -1.0], [-1.0, -1.0])
accumulate([-2.0, 2.0], [3.0, -2.0])
self.assertEqual(accumulator.step, 3)
self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
_check_local_values([2.0, 3.0], [1.0, -2.0])
apply_grad()
self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
accumulator.reset()
self.assertEqual(accumulator.step, 0)
self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
_check_local_values([0.0, 0.0], [0.0, 0.0])