mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
TF version of the trainer (#4017)
* First commit to add a TF version of the trainer. * Make the TF trainer closer to what looks the PT trainer * Refactoring common code between the PT and TF trainer into an util file. * Some bugfix + better similarity with the PT trainer * Add missing class in transformers init * Bugfix over prediction + use classification report instead of simple metrics * Fix name error * Fix optimization tests + style * Apply style * Several bugfix for multi-gpu training * Apply style * Apply style * Add glue example for the TF trainer * Several bugix + address the reviews * Fix on the TF training args file * Add a debug mode * Bugfix in utils_ner.py when segment_ids is None * Apply style * Apply style * Add TPU strategy * Fix selection strategy
This commit is contained in:
parent
25296b12aa
commit
aad50151f3
@ -1,628 +1,281 @@
|
||||
# coding=utf-8
|
||||
import collections
|
||||
import datetime
|
||||
import glob
|
||||
import math
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-tuning the library models for named entity recognition."""
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from absl import app, flags, logging
|
||||
from fastprogress import master_bar, progress_bar
|
||||
from seqeval import metrics
|
||||
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
|
||||
|
||||
from transformers import (
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
GradientAccumulator,
|
||||
PreTrainedTokenizer,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
TFAutoModelForTokenClassification,
|
||||
create_optimizer,
|
||||
TFTrainer,
|
||||
TFTrainingArguments,
|
||||
)
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
from utils_ner import Split, TFNerDataset, get_labels
|
||||
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
flags.DEFINE_string(
|
||||
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) for the task."
|
||||
)
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_name_or_path", None, "Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
|
||||
flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
|
||||
)
|
||||
|
||||
flags.DEFINE_string("config_name", None, "Pretrained config name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string("tokenizer_name", None, "Pretrained tokenizer name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string("cache_dir", None, "Where do you want to store the pre-trained models downloaded from s3")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_seq_length",
|
||||
128,
|
||||
"The maximum total input sentence length after tokenization. "
|
||||
"Sequences longer than this will be truncated, sequences shorter "
|
||||
"will be padded.",
|
||||
)
|
||||
|
||||
flags.DEFINE_string(
|
||||
"tpu",
|
||||
None,
|
||||
"The Cloud TPU to use for training. This should be either the name "
|
||||
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
||||
"url.",
|
||||
)
|
||||
|
||||
flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")
|
||||
|
||||
flags.DEFINE_boolean("do_train", False, "Whether to run training.")
|
||||
|
||||
flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")
|
||||
|
||||
flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
|
||||
)
|
||||
|
||||
flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")
|
||||
|
||||
flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")
|
||||
|
||||
flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"gradient_accumulation_steps", 1, "Number of updates steps to accumulate before performing a backward/update pass."
|
||||
)
|
||||
|
||||
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
|
||||
|
||||
flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")
|
||||
|
||||
flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")
|
||||
|
||||
flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")
|
||||
|
||||
flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_steps", -1, "If > 0: set total number of training steps to perform. Override num_train_epochs."
|
||||
)
|
||||
|
||||
flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")
|
||||
|
||||
flags.DEFINE_integer("logging_steps", 50, "Log every X updates steps.")
|
||||
|
||||
flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X updates steps.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"eval_all_checkpoints",
|
||||
False,
|
||||
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
|
||||
flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA even if it is available")
|
||||
|
||||
flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")
|
||||
|
||||
flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")
|
||||
|
||||
flags.DEFINE_integer("seed", 42, "random seed for initialization")
|
||||
|
||||
flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"gpus",
|
||||
"0",
|
||||
"Comma separated list of gpus devices. If only one, switch to single "
|
||||
"gpu strategy, if None takes all the gpus available.",
|
||||
)
|
||||
|
||||
|
||||
def train(
|
||||
args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
|
||||
):
|
||||
if args["max_steps"] > 0:
|
||||
num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
|
||||
args["num_train_epochs"] = 1
|
||||
else:
|
||||
num_train_steps = (
|
||||
math.ceil(num_train_examples / train_batch_size)
|
||||
// args["gradient_accumulation_steps"]
|
||||
* args["num_train_epochs"]
|
||||
)
|
||||
|
||||
writer = tf.summary.create_file_writer("/tmp/mylogs")
|
||||
|
||||
with strategy.scope():
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||
)
|
||||
optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])
|
||||
|
||||
if args["fp16"]:
|
||||
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")
|
||||
|
||||
loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
|
||||
gradient_accumulator = GradientAccumulator()
|
||||
|
||||
logging.info("***** Running training *****")
|
||||
logging.info(" Num examples = %d", num_train_examples)
|
||||
logging.info(" Num Epochs = %d", args["num_train_epochs"])
|
||||
logging.info(" Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
|
||||
logging.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
train_batch_size * args["gradient_accumulation_steps"],
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
logging.info(" Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
|
||||
logging.info(" Total training steps = %d", num_train_steps)
|
||||
|
||||
model.summary()
|
||||
|
||||
@tf.function
|
||||
def apply_gradients():
|
||||
grads_and_vars = []
|
||||
|
||||
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
|
||||
if gradient is not None:
|
||||
scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
|
||||
grads_and_vars.append((scaled_gradient, variable))
|
||||
else:
|
||||
grads_and_vars.append((gradient, variable))
|
||||
|
||||
optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
|
||||
gradient_accumulator.reset()
|
||||
|
||||
@tf.function
|
||||
def train_step(train_features, train_labels):
|
||||
def step_fn(train_features, train_labels):
|
||||
inputs = {"attention_mask": train_features["attention_mask"], "training": True}
|
||||
|
||||
if "token_type_ids" in train_features:
|
||||
inputs["token_type_ids"] = train_features["token_type_ids"]
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
logits = model(train_features["input_ids"], **inputs)[0]
|
||||
active_loss = tf.reshape(train_labels, (-1,)) != pad_token_label_id
|
||||
active_logits = tf.boolean_mask(tf.reshape(logits, (-1, len(labels))), active_loss)
|
||||
active_labels = tf.boolean_mask(tf.reshape(train_labels, (-1,)), active_loss)
|
||||
cross_entropy = loss_fct(active_labels, active_logits)
|
||||
loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
|
||||
grads = tape.gradient(loss, model.trainable_variables)
|
||||
|
||||
gradient_accumulator(grads)
|
||||
|
||||
return cross_entropy
|
||||
|
||||
per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
|
||||
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
|
||||
|
||||
return mean_loss
|
||||
|
||||
current_time = datetime.datetime.now()
|
||||
train_iterator = master_bar(range(args["num_train_epochs"]))
|
||||
global_step = 0
|
||||
logging_loss = 0.0
|
||||
|
||||
for epoch in train_iterator:
|
||||
epoch_iterator = progress_bar(
|
||||
train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
|
||||
)
|
||||
step = 1
|
||||
|
||||
with strategy.scope():
|
||||
for train_features, train_labels in epoch_iterator:
|
||||
loss = train_step(train_features, train_labels)
|
||||
|
||||
if step % args["gradient_accumulation_steps"] == 0:
|
||||
strategy.experimental_run_v2(apply_gradients)
|
||||
|
||||
loss_metric(loss)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
|
||||
# Log metrics
|
||||
if (
|
||||
args["n_device"] == 1 and args["evaluate_during_training"]
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
y_true, y_pred, eval_loss = evaluate(
|
||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||
)
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
logging.info("Eval at step " + str(global_step) + "\n" + report)
|
||||
logging.info("eval_loss: " + str(eval_loss))
|
||||
|
||||
precision = metrics.precision_score(y_true, y_pred)
|
||||
recall = metrics.recall_score(y_true, y_pred)
|
||||
f1 = metrics.f1_score(y_true, y_pred)
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("eval_loss", eval_loss, global_step)
|
||||
tf.summary.scalar("precision", precision, global_step)
|
||||
tf.summary.scalar("recall", recall, global_step)
|
||||
tf.summary.scalar("f1", f1, global_step)
|
||||
|
||||
lr = optimizer.learning_rate
|
||||
learning_rate = lr(step)
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("lr", learning_rate, global_step)
|
||||
tf.summary.scalar(
|
||||
"loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
|
||||
)
|
||||
|
||||
logging_loss = loss_metric.result()
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("loss", loss_metric.result(), step=step)
|
||||
|
||||
if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
model.save_pretrained(output_dir)
|
||||
logging.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
train_iterator.child.comment = f"loss : {loss_metric.result()}"
|
||||
step += 1
|
||||
|
||||
train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")
|
||||
|
||||
loss_metric.reset_states()
|
||||
|
||||
logging.info(" Training took time = {}".format(datetime.datetime.now() - current_time))
|
||||
|
||||
|
||||
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
|
||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||
eval_dataset, size = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
|
||||
preds = None
|
||||
num_eval_steps = math.ceil(size / eval_batch_size)
|
||||
master = master_bar(range(1))
|
||||
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||
loss = 0.0
|
||||
|
||||
logging.info("***** Running evaluation *****")
|
||||
logging.info(" Num examples = %d", size)
|
||||
logging.info(" Batch size = %d", eval_batch_size)
|
||||
|
||||
for eval_features, eval_labels in eval_iterator:
|
||||
inputs = {"attention_mask": eval_features["attention_mask"], "training": False}
|
||||
|
||||
if "token_type_ids" in eval_features:
|
||||
inputs["token_type_ids"] = eval_features["token_type_ids"]
|
||||
|
||||
with strategy.scope():
|
||||
logits = model(eval_features["input_ids"], **inputs)[0]
|
||||
active_loss = tf.reshape(eval_labels, (-1,)) != pad_token_label_id
|
||||
active_logits = tf.boolean_mask(tf.reshape(logits, (-1, len(labels))), active_loss)
|
||||
active_labels = tf.boolean_mask(tf.reshape(eval_labels, (-1,)), active_loss)
|
||||
cross_entropy = loss_fct(active_labels, active_logits)
|
||||
loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)
|
||||
|
||||
if preds is None:
|
||||
preds = logits.numpy()
|
||||
label_ids = eval_labels.numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.numpy(), axis=0)
|
||||
label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)
|
||||
|
||||
preds = np.argmax(preds, axis=2)
|
||||
y_pred = [[] for _ in range(label_ids.shape[0])]
|
||||
y_true = [[] for _ in range(label_ids.shape[0])]
|
||||
loss = loss / num_eval_steps
|
||||
|
||||
for i in range(label_ids.shape[0]):
|
||||
for j in range(label_ids.shape[1]):
|
||||
if label_ids[i, j] != pad_token_label_id:
|
||||
y_pred[i].append(labels[preds[i, j] - 1])
|
||||
y_true[i].append(labels[label_ids[i, j] - 1])
|
||||
|
||||
return y_true, y_pred, loss.numpy()
|
||||
|
||||
|
||||
def load_cache(cached_file, tokenizer: PreTrainedTokenizer, max_seq_length):
|
||||
name_to_features = {
|
||||
"input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"attention_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
}
|
||||
# TODO Find a cleaner way to do this.
|
||||
if "token_type_ids" in tokenizer.model_input_names:
|
||||
name_to_features["token_type_ids"] = tf.io.FixedLenFeature([max_seq_length], tf.int64)
|
||||
|
||||
def _decode_record(record):
|
||||
example = tf.io.parse_single_example(record, name_to_features)
|
||||
features = {}
|
||||
features["input_ids"] = example["input_ids"]
|
||||
features["attention_mask"] = example["attention_mask"]
|
||||
if "token_type_ids" in example:
|
||||
features["token_type_ids"] = example["token_type_ids"]
|
||||
|
||||
return features, example["label_ids"]
|
||||
|
||||
d = tf.data.TFRecordDataset(cached_file)
|
||||
d = d.map(_decode_record, num_parallel_calls=4)
|
||||
count = d.reduce(0, lambda x, _: x + 1)
|
||||
|
||||
return d, count.numpy()
|
||||
|
||||
|
||||
def save_cache(features, cached_features_file):
|
||||
writer = tf.io.TFRecordWriter(cached_features_file)
|
||||
|
||||
for (ex_index, feature) in enumerate(features):
|
||||
if ex_index % 5000 == 0:
|
||||
logging.info("Writing example %d of %d" % (ex_index, len(features)))
|
||||
|
||||
def create_int_feature(values):
|
||||
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return f
|
||||
|
||||
record_feature = collections.OrderedDict()
|
||||
record_feature["input_ids"] = create_int_feature(feature.input_ids)
|
||||
record_feature["attention_mask"] = create_int_feature(feature.attention_mask)
|
||||
if feature.token_type_ids is not None:
|
||||
record_feature["token_type_ids"] = create_int_feature(feature.token_type_ids)
|
||||
record_feature["label_ids"] = create_int_feature(feature.label_ids)
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))
|
||||
|
||||
writer.write(tf_example.SerializeToString())
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
|
||||
drop_remainder = True if args["tpu"] or mode == "train" else False
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args["data_dir"],
|
||||
"cached_{}_{}_{}.tf_record".format(mode, tokenizer.__class__.__name__, str(args["max_seq_length"])),
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
|
||||
# If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
|
||||
# or just modify its tokenizer_config.json.
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
|
||||
logging.info("Loading features from cached file %s", cached_features_file)
|
||||
dataset, size = load_cache(cached_features_file, tokenizer, args["max_seq_length"])
|
||||
else:
|
||||
logging.info("Creating features from dataset file at %s", args["data_dir"])
|
||||
examples = read_examples_from_file(args["data_dir"], mode)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
args["max_seq_length"],
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args["model_type"] in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args["model_type"] in ["xlnet"]),
|
||||
# pad on the left for xlnet
|
||||
pad_token=tokenizer.pad_token_id,
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
)
|
||||
logging.info("Saving features into cached file %s", cached_features_file)
|
||||
save_cache(features, cached_features_file)
|
||||
dataset, size = load_cache(cached_features_file, tokenizer, args["max_seq_length"])
|
||||
|
||||
if mode == "train":
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder)
|
||||
dataset = dataset.prefetch(buffer_size=batch_size)
|
||||
|
||||
return dataset, size
|
||||
|
||||
|
||||
def main(_):
|
||||
logging.set_verbosity(logging.INFO)
|
||||
args = flags.FLAGS.flag_values_dict()
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
data_dir: str = field(
|
||||
metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."}
|
||||
)
|
||||
labels: Optional[str] = field(
|
||||
metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(args["output_dir"])
|
||||
and os.listdir(args["output_dir"])
|
||||
and args["do_train"]
|
||||
and not args["overwrite_output_dir"]
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args["output_dir"]
|
||||
)
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
if args["fp16"]:
|
||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
|
||||
|
||||
if args["tpu"]:
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
||||
args["n_device"] = args["num_tpu_cores"]
|
||||
elif len(args["gpus"].split(",")) > 1:
|
||||
args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
|
||||
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
|
||||
elif args["no_cuda"]:
|
||||
args["n_device"] = 1
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||
else:
|
||||
args["n_device"] = len(args["gpus"].split(","))
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])
|
||||
|
||||
logging.warning(
|
||||
"n_device: %s, distributed training: %s, 16-bits training: %s",
|
||||
args["n_device"],
|
||||
bool(args["n_device"] > 1),
|
||||
args["fp16"],
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(
|
||||
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.n_gpu,
|
||||
bool(training_args.n_gpu > 1),
|
||||
training_args.fp16,
|
||||
)
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
labels = get_labels(args["labels"])
|
||||
# Prepare Token Classification task
|
||||
labels = get_labels(data_args.labels)
|
||||
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
|
||||
num_labels = len(labels)
|
||||
pad_token_label_id = -1
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=args["cache_dir"],
|
||||
id2label=label_map,
|
||||
label2id={label: i for i, label in enumerate(labels)},
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast,
|
||||
)
|
||||
|
||||
logging.info("Training/evaluation parameters %s", args)
|
||||
args["model_type"] = config.model_type
|
||||
with training_args.strategy.scope():
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_pt=bool(".bin" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
TFNerDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
model_type=config.model_type,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.train,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
TFNerDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
model_type=config.model_type,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.dev,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
|
||||
def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
|
||||
preds = np.argmax(predictions, axis=2)
|
||||
batch_size, seq_len = preds.shape
|
||||
out_label_list = [[] for _ in range(batch_size)]
|
||||
preds_list = [[] for _ in range(batch_size)]
|
||||
|
||||
for i in range(batch_size):
|
||||
for j in range(seq_len):
|
||||
if label_ids[i, j] != -1:
|
||||
out_label_list[i].append(label_map[label_ids[i][j]])
|
||||
preds_list[i].append(label_map[preds[i][j]])
|
||||
|
||||
return preds_list, out_label_list
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
|
||||
|
||||
return {
|
||||
"precision": precision_score(out_label_list, preds_list),
|
||||
"recall": recall_score(out_label_list, preds_list),
|
||||
"f1": f1_score(out_label_list, preds_list),
|
||||
}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = TFTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset.get_dataset() if train_dataset else None,
|
||||
eval_dataset=eval_dataset.get_dataset() if eval_dataset else None,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
# Training
|
||||
if args["do_train"]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
||||
do_lower_case=args["do_lower_case"],
|
||||
cache_dir=args["cache_dir"],
|
||||
)
|
||||
|
||||
with strategy.scope():
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(
|
||||
args["model_name_or_path"],
|
||||
from_pt=bool(".bin" in args["model_name_or_path"]),
|
||||
config=config,
|
||||
cache_dir=args["cache_dir"],
|
||||
)
|
||||
|
||||
train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
|
||||
train_dataset, num_train_examples = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
|
||||
)
|
||||
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
|
||||
train(
|
||||
args,
|
||||
strategy,
|
||||
train_dataset,
|
||||
tokenizer,
|
||||
model,
|
||||
num_train_examples,
|
||||
labels,
|
||||
train_batch_size,
|
||||
pad_token_label_id,
|
||||
)
|
||||
|
||||
os.makedirs(args["output_dir"], exist_ok=True)
|
||||
|
||||
logging.info("Saving model to %s", args["output_dir"])
|
||||
|
||||
model.save_pretrained(args["output_dir"])
|
||||
tokenizer.save_pretrained(args["output_dir"])
|
||||
if training_args.do_train:
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
if args["do_eval"]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
checkpoints = []
|
||||
results = []
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
if args["eval_all_checkpoints"]:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c)
|
||||
for c in sorted(
|
||||
glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
|
||||
key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
|
||||
)
|
||||
)
|
||||
result = trainer.evaluate()
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
|
||||
|
||||
logging.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
|
||||
if len(checkpoints) == 0:
|
||||
checkpoints.append(args["output_dir"])
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
writer.write("%s = %s\n" % (key, value))
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||
results.update(result)
|
||||
|
||||
with strategy.scope():
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)
|
||||
|
||||
y_true, y_pred, eval_loss = evaluate(
|
||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||
)
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
if global_step:
|
||||
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
|
||||
|
||||
output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")
|
||||
|
||||
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
|
||||
for res in results:
|
||||
for key, val in res.items():
|
||||
if "loss" in key:
|
||||
logging.info(key + " = " + str(val))
|
||||
writer.write(key + " = " + str(val))
|
||||
writer.write("\n")
|
||||
else:
|
||||
logging.info(key)
|
||||
logging.info("\n" + report)
|
||||
writer.write(key + "\n")
|
||||
writer.write(report)
|
||||
writer.write("\n")
|
||||
|
||||
if args["do_predict"]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
|
||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||
predict_dataset, _ = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
test_dataset = TFNerDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
model_type=config.model_type,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.test,
|
||||
)
|
||||
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||
output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
|
||||
output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
predictions, label_ids, metrics = trainer.predict(test_dataset.get_dataset())
|
||||
preds_list, labels_list = align_predictions(predictions, label_ids)
|
||||
report = classification_report(labels_list, preds_list)
|
||||
|
||||
logging.info("\n" + report)
|
||||
logger.info("\n%s", report)
|
||||
|
||||
writer.write(report)
|
||||
writer.write("\n\nloss = " + str(pred_loss))
|
||||
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||
|
||||
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
|
||||
with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
writer.write("%s\n" % report)
|
||||
|
||||
# Save predictions
|
||||
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
||||
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
|
||||
example_id = 0
|
||||
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
writer.write(line)
|
||||
|
||||
if not y_pred[example_id]:
|
||||
if not preds_list[example_id]:
|
||||
example_id += 1
|
||||
elif y_pred[example_id]:
|
||||
output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
|
||||
elif preds_list[example_id]:
|
||||
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
|
||||
|
||||
writer.write(output_line)
|
||||
else:
|
||||
logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
|
||||
logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("data_dir")
|
||||
flags.mark_flag_as_required("output_dir")
|
||||
flags.mark_flag_as_required("model_name_or_path")
|
||||
app.run(main)
|
||||
main()
|
||||
|
@ -22,11 +22,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from transformers import PreTrainedTokenizer, torch_distributed_zero_first
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -68,70 +64,170 @@ class Split(Enum):
|
||||
test = "test"
|
||||
|
||||
|
||||
class NerDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from transformers import torch_distributed_zero_first
|
||||
|
||||
features: List[InputFeatures]
|
||||
pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
|
||||
# Use cross entropy ignore_index as padding label id so that only
|
||||
# real label ids contribute to the loss later.
|
||||
class NerDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
labels: List[str],
|
||||
model_type: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
local_rank=-1,
|
||||
):
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
|
||||
)
|
||||
features: List[InputFeatures]
|
||||
pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
|
||||
# Use cross entropy ignore_index as padding label id so that only
|
||||
# real label ids contribute to the loss later.
|
||||
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
labels: List[str],
|
||||
model_type: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
local_rank=-1,
|
||||
):
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
|
||||
)
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
examples = read_examples_from_file(data_dir, mode)
|
||||
# TODO clean up all this to leverage built-in features of tokenizers
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(model_type in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(model_type in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(tokenizer.padding_side == "left"),
|
||||
pad_token=tokenizer.pad_token_id,
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
pad_token_label_id=self.pad_token_label_id,
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
examples = read_examples_from_file(data_dir, mode)
|
||||
# TODO clean up all this to leverage built-in features of tokenizers
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(model_type in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(model_type in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(tokenizer.padding_side == "left"),
|
||||
pad_token=tokenizer.pad_token_id,
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
pad_token_label_id=self.pad_token_label_id,
|
||||
)
|
||||
if local_rank in [-1, 0]:
|
||||
logger.info(f"Saving features into cached file {cached_features_file}")
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
class TFNerDataset:
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
pad_token_label_id: int = -1
|
||||
# Use cross entropy ignore_index as padding label id so that only
|
||||
# real label ids contribute to the loss later.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
labels: List[str],
|
||||
model_type: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
):
|
||||
examples = read_examples_from_file(data_dir, mode)
|
||||
# TODO clean up all this to leverage built-in features of tokenizers
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(model_type in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(model_type in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(tokenizer.padding_side == "left"),
|
||||
pad_token=tokenizer.pad_token_id,
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
pad_token_label_id=self.pad_token_label_id,
|
||||
)
|
||||
|
||||
def gen():
|
||||
for ex in self.features:
|
||||
if ex.token_type_ids is None:
|
||||
yield (
|
||||
{"input_ids": ex.input_ids, "attention_mask": ex.attention_mask},
|
||||
ex.label_ids,
|
||||
)
|
||||
else:
|
||||
yield (
|
||||
{
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
},
|
||||
ex.label_ids,
|
||||
)
|
||||
|
||||
if "token_type_ids" not in tokenizer.model_input_names:
|
||||
self.dataset = tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
|
||||
(
|
||||
{"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])},
|
||||
tf.TensorShape([None]),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.dataset = tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
|
||||
(
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"token_type_ids": tf.TensorShape([None]),
|
||||
},
|
||||
tf.TensorShape([None]),
|
||||
),
|
||||
)
|
||||
if local_rank in [-1, 0]:
|
||||
logger.info(f"Saving features into cached file {cached_features_file}")
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
def get_dataset(self):
|
||||
return self.dataset
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
|
||||
def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]:
|
||||
|
@ -1,105 +1,229 @@
|
||||
import os
|
||||
# coding=utf-8
|
||||
""" Fine-tuning the library models for sequence classification."""
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertForSequenceClassification,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizer,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFTrainer,
|
||||
TFTrainingArguments,
|
||||
glue_compute_metrics,
|
||||
glue_convert_examples_to_features,
|
||||
glue_output_modes,
|
||||
glue_processors,
|
||||
glue_tasks_num_labels,
|
||||
)
|
||||
|
||||
|
||||
# script parameters
|
||||
BATCH_SIZE = 32
|
||||
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
||||
USE_XLA = False
|
||||
USE_AMP = False
|
||||
EPOCHS = 3
|
||||
|
||||
TASK = "mrpc"
|
||||
|
||||
if TASK == "sst-2":
|
||||
TFDS_TASK = "sst2"
|
||||
elif TASK == "sts-b":
|
||||
TFDS_TASK = "stsb"
|
||||
else:
|
||||
TFDS_TASK = TASK
|
||||
|
||||
num_labels = len(glue_processors[TASK]().get_labels())
|
||||
print(num_labels)
|
||||
|
||||
tf.config.optimizer.set_jit(USE_XLA)
|
||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
||||
|
||||
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
|
||||
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased", config=config)
|
||||
|
||||
# Load dataset via TensorFlow Datasets
|
||||
data, info = tensorflow_datasets.load(f"glue/{TFDS_TASK}", with_info=True)
|
||||
train_examples = info.splits["train"].num_examples
|
||||
|
||||
# MNLI expects either validation_matched or validation_mismatched
|
||||
valid_examples = info.splits["validation"].num_examples
|
||||
|
||||
# Prepare dataset for GLUE as a tf.data.Dataset instance
|
||||
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, max_length=128, task=TASK)
|
||||
|
||||
# MNLI expects either validation_matched or validation_mismatched
|
||||
valid_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, max_length=128, task=TASK)
|
||||
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
||||
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
||||
|
||||
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
|
||||
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
||||
if USE_AMP:
|
||||
# loss scaling is currently required when using mixed precision
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
|
||||
class Split(Enum):
|
||||
train = "train"
|
||||
dev = "validation"
|
||||
test = "test"
|
||||
|
||||
|
||||
if num_labels == 1:
|
||||
loss = tf.keras.losses.MeanSquaredError()
|
||||
else:
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
def get_tfds(
|
||||
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
|
||||
):
|
||||
if task_name == "mnli-mm" and mode == Split.dev:
|
||||
tfds_name = "mnli_mismatched"
|
||||
elif task_name == "mnli-mm" and mode == Split.train:
|
||||
tfds_name = "mnli"
|
||||
elif task_name == "mnli" and mode == Split.dev:
|
||||
tfds_name = "mnli_matched"
|
||||
elif task_name == "sst-2":
|
||||
tfds_name = "sst2"
|
||||
elif task_name == "sts-b":
|
||||
tfds_name = "stsb"
|
||||
else:
|
||||
tfds_name = task_name
|
||||
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
||||
ds = tfds.load("glue/" + tfds_name, split=mode.value)
|
||||
|
||||
# Train and evaluate using tf.keras.Model.fit()
|
||||
train_steps = train_examples // BATCH_SIZE
|
||||
valid_steps = valid_examples // EVAL_BATCH_SIZE
|
||||
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
||||
|
||||
history = model.fit(
|
||||
train_dataset,
|
||||
epochs=EPOCHS,
|
||||
steps_per_epoch=train_steps,
|
||||
validation_data=valid_dataset,
|
||||
validation_steps=valid_steps,
|
||||
)
|
||||
|
||||
# Save TF2 model
|
||||
os.makedirs("./save/", exist_ok=True)
|
||||
model.save_pretrained("./save/")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TASK == "mrpc":
|
||||
# Load the TensorFlow model in PyTorch for inspection
|
||||
# This is to demo the interoperability between the two frameworks, you don't have to
|
||||
# do this in real life (you can run the inference on the TF model).
|
||||
pytorch_model = BertForSequenceClassification.from_pretrained("./save/", from_tf=True)
|
||||
|
||||
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
|
||||
sentence_0 = "This research was consistent with his findings."
|
||||
sentence_1 = "His findings were compatible with this research."
|
||||
sentence_2 = "His findings were not compatible with this research."
|
||||
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors="pt")
|
||||
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors="pt")
|
||||
@dataclass
|
||||
class GlueDataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
|
||||
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
||||
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
||||
print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")
|
||||
print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")
|
||||
Using `HfArgumentParser` we can turn this class
|
||||
into argparse arguments to be able to specify them on
|
||||
the command line.
|
||||
"""
|
||||
|
||||
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.task_name = self.task_name.lower()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
|
||||
# If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
|
||||
# or just modify its tokenizer_config.json.
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
parser = HfArgumentParser((ModelArguments, GlueDataTrainingArguments, TFTrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if (
    os.path.exists(training_args.output_dir)
    and os.listdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
):
    raise ValueError(
        f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
    )

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(
    "n_gpu: %s, distributed training: %s, 16-bits training: %s",
    training_args.n_gpu,
    bool(training_args.n_gpu > 1),
    training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

try:
    num_labels = glue_tasks_num_labels["mnli" if data_args.task_name == "mnli-mm" else data_args.task_name]
    output_mode = glue_output_modes[data_args.task_name]
except KeyError:
    raise ValueError("Task not found: %s" % (data_args.task_name))

# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

with training_args.strategy.scope():
    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_pt=bool(".bin" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

# Get datasets
train_dataset = (
    get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
    if training_args.do_train
    else None
)
eval_dataset = (
    get_tfds(
        task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
    )
    if training_args.do_eval
    else None
)

def compute_metrics(p: EvalPrediction) -> Dict:
    if output_mode == "classification":
        preds = np.argmax(p.predictions, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(p.predictions)
    return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

# Initialize our Trainer
trainer = TFTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# Training
if training_args.do_train:
    trainer.train()
    trainer.save_model()
    tokenizer.save_pretrained(training_args.output_dir)

# Evaluation
results = {}
if training_args.do_eval:
    logger.info("*** Evaluate ***")

    result = trainer.evaluate()
    output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")

    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results *****")

        for key, value in result.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))

        results.update(result)

return results


if __name__ == "__main__":
    main()
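For orientation, a hedged sketch of what `compute_metrics` receives at evaluation time: an `EvalPrediction` bundling the stacked logits and label ids, with the argmax over the logit axis yielding class predictions for classification tasks. The values below are made up for illustration; only `EvalPrediction` and NumPy are assumed.

import numpy as np

from transformers import EvalPrediction

# Hypothetical logits for three examples over two classes, plus gold labels.
logits = np.array([[0.2, 1.3], [2.0, -0.5], [0.1, 0.4]])
labels = np.array([1, 0, 0])

p = EvalPrediction(predictions=logits, label_ids=labels)
preds = np.argmax(p.predictions, axis=1)  # the "classification" branch above
print((preds == p.label_ids).mean())  # 2 of 3 correct -> ~0.667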
@ -145,7 +145,9 @@ from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
+from .trainer_utils import EvalPrediction
 from .training_args import TrainingArguments
+from .training_args_tf import TFTrainingArguments


 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@ -502,6 +504,9 @@ if is_tf_available():
     # Optimization
     from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator

+    # Trainer
+    from .trainer_tf import TFTrainer
+

 if not is_tf_available() and not is_torch_available():
     logger.warning(
@ -21,9 +21,11 @@ import tensorflow as tf


 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
-    """Applys a warmup schedule on a given learning rate decay schedule."""
+    """Applies a warmup schedule on a given learning rate decay schedule."""

-    def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None):
+    def __init__(
+        self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None,
+    ):
         super().__init__()
         self.initial_learning_rate = initial_learning_rate
         self.warmup_steps = warmup_steps
@ -56,34 +58,34 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
         }


-def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
+def create_optimizer(init_lr, num_train_steps, num_warmup_steps, end_lr=0.0, optimizer_type="adamw"):
     """Creates an optimizer with learning rate schedule."""
     # Implements linear decay of the learning rate.
-    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
-        initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=0.0
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=end_lr,
     )
     if num_warmup_steps:
-        learning_rate_fn = WarmUp(
-            initial_learning_rate=init_lr, decay_schedule_fn=learning_rate_fn, warmup_steps=num_warmup_steps
+        lr_schedule = WarmUp(
+            initial_learning_rate=init_lr, decay_schedule_fn=lr_schedule, warmup_steps=num_warmup_steps,
         )

     optimizer = AdamWeightDecay(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
         epsilon=1e-6,
         exclude_from_weight_decay=["layer_norm", "bias"],
     )

     return optimizer
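As a usage sketch (the step counts are illustrative): the schedule returned here warms the learning rate up linearly over `num_warmup_steps`, then follows the polynomial decay down to `end_lr`, and the resulting `AdamWeightDecay` instance exposes it as a callable `learning_rate`.

from transformers import create_optimizer

# Hypothetical horizon: 1000 updates with 10% warmup.
optimizer = create_optimizer(init_lr=3e-5, num_train_steps=1000, num_warmup_steps=100)

print(optimizer.learning_rate(50).numpy())    # mid-warmup, roughly 1.5e-5
print(optimizer.learning_rate(1000).numpy())  # fully decayed, roughly 0.0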
 class AdamWeightDecay(tf.keras.optimizers.Adam):
     """Adam enables L2 weight decay and clip_by_global_norm on gradients.

     Just adding the square of the weights to the loss function is *not* the
     correct way of using L2 regularization/weight decay with Adam, since that will
     interact with the m and v parameters in strange ways.

     Instead we want to decay the weights in a manner that doesn't interact with
     the m/v parameters. This is equivalent to adding the square of the weights to
     the loss with plain (non-momentum) SGD.
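To make the distinction concrete, a toy sketch (numbers invented, not this class's code path): decoupled decay shrinks a weight directly by `lr * decay_rate * w`, bypassing Adam's moment estimates entirely.

lr, decay_rate, w = 0.1, 0.01, 2.0
w_decayed = w - lr * decay_rate * w  # 2.0 - 0.002 = 1.998
# Folding 0.5 * decay_rate * w**2 into the loss instead would route the decay
# gradient through Adam's m/v normalization, changing its effective magnitude.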
@ -111,24 +113,26 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     def from_config(cls, config):
         """Creates an optimizer from its config with WarmUp custom object."""
         custom_objects = {"WarmUp": WarmUp}
-        return super().from_config(config, custom_objects=custom_objects)
+        return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)

     def _prepare_local(self, var_device, var_dtype, apply_state):
-        super()._prepare_local(var_device, var_dtype, apply_state)
-        apply_state["weight_decay_rate"] = tf.constant(self.weight_decay_rate, name="adam_weight_decay_rate")
+        super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
+        apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
+            self.weight_decay_rate, name="adam_weight_decay_rate"
+        )

     def _decay_weights_op(self, var, learning_rate, apply_state):
         do_decay = self._do_use_weight_decay(var.name)
         if do_decay:
             return var.assign_sub(
-                learning_rate * var * apply_state["weight_decay_rate"], use_locking=self._use_locking
+                learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
+                use_locking=self._use_locking,
             )
         return tf.no_op()

-    def apply_gradients(self, grads_and_vars, clip_norm, name=None):
+    def apply_gradients(self, grads_and_vars, name=None):
         grads, tvars = list(zip(*grads_and_vars))
-        (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
-        return super().apply_gradients(zip(grads, tvars))
+        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name,)

     def _get_lr(self, var_device, var_dtype, apply_state):
         """Retrieves the learning rate with the given state."""
@ -147,13 +151,13 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
         lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
         decay = self._decay_weights_op(var, lr_t, apply_state)
         with tf.control_dependencies([decay]):
-            return super()._resource_apply_dense(grad, var, **kwargs)
+            return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)

     def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
         lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
         decay = self._decay_weights_op(var, lr_t, apply_state)
         with tf.control_dependencies([decay]):
-            return super()._resource_apply_sparse(grad, var, indices, **kwargs)
+            return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)

     def get_config(self):
         config = super().get_config()
@ -177,71 +181,65 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
         return True

-# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
+# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
 class GradientAccumulator(object):
-    """Distribution strategies-aware gradient accumulation utility."""
+    """Gradient accumulation utility.
+    When used with a distribution strategy, the accumulator should be called in a
+    replica context. Gradients will be accumulated locally on each replica and
+    without synchronization. Users should then call ``.gradients``, scale the
+    gradients if required, and pass the result to ``apply_gradients``.
+    """
+
+    # We use the ON_READ synchronization policy so that no synchronization is
+    # performed on assignment. To get the value, we call .value() which returns the
+    # value on the current replica without synchronization.

     def __init__(self):
         """Initializes the accumulator."""
         self._gradients = []
-        self._accum_steps = tf.Variable(
-            initial_value=0, dtype=tf.int64, trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
-        )
+        self._accum_steps = None

     @property
     def step(self):
         """Number of accumulated steps."""
+        if self._accum_steps is None:
+            self._accum_steps = tf.Variable(
+                tf.constant(0, dtype=tf.int64), trainable=False, synchronization=tf.VariableSynchronization.ON_READ,
+            )
+
         return self._accum_steps.value()

     @property
     def gradients(self):
-        """The accumulated gradients."""
-        return list(
-            gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients()
-        )
+        """The accumulated gradients on the current replica."""
+        if not self._gradients:
+            raise ValueError("The accumulator should be called first to initialize the gradients")
+        return list(gradient.value() for gradient in self._gradients)

     def __call__(self, gradients):
-        """Accumulates :obj:`gradients`."""
+        """Accumulates :obj:`gradients` on the current replica."""
         if not self._gradients:
+            _ = self.step  # Create the step variable.
             self._gradients.extend(
                 [
-                    tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient
+                    tf.Variable(
+                        tf.zeros_like(gradient), trainable=False, synchronization=tf.VariableSynchronization.ON_READ,
+                    )
                     for gradient in gradients
                 ]
             )

         if len(gradients) != len(self._gradients):
             raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))

-        for accum_gradient, gradient in zip(self._get_replica_gradients(), gradients):
-            if accum_gradient is not None and gradient is not None:
-                accum_gradient.assign_add(gradient)
+        for accum_gradient, gradient in zip(self._gradients, gradients):
+            accum_gradient.assign_add(gradient)

         self._accum_steps.assign_add(1)

     def reset(self):
-        """Resets the accumulated gradients."""
-        if self._gradients:
-            self._accum_steps.assign(0)
-
-            for gradient in self._get_replica_gradients():
-                if gradient is not None:
-                    gradient.assign(tf.zeros_like(gradient))
-
-    def _get_replica_gradients(self):
-        if tf.distribute.has_strategy():
-            # In a replica context, we want to accumulate gradients on each replica
-            # without synchronization, so we directly assign the value of the
-            # current replica.
-            replica_context = tf.distribute.get_replica_context()
-
-            if replica_context is None or tf.distribute.get_strategy().num_replicas_in_sync == 1:
-                return self._gradients
-
-            return (
-                gradient.device_map.select_for_current_replica(gradient.values, replica_context)
-                for gradient in self._gradients
-                if gradient is not None
-            )
-        else:
-            return self._gradients
+        """Resets the accumulated gradients on the current replica."""
+        if not self._gradients:
+            return
+        self._accum_steps.assign(0)
+        for gradient in self._gradients:
+            gradient.assign(tf.zeros_like(gradient))
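A minimal single-replica sketch of the intended call pattern (the variable and optimizer are illustrative): feed per-micro-batch gradients to the accumulator, read `.gradients`, scale by the step count, apply, then `reset()`.

import tensorflow as tf

from transformers import GradientAccumulator

accumulator = GradientAccumulator()
variable = tf.Variable([1.0, 2.0])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

accumulator([tf.constant([0.5, 0.5])])  # first micro-batch
accumulator([tf.constant([1.5, 0.5])])  # second micro-batch

grads = [g / tf.cast(accumulator.step, g.dtype) for g in accumulator.gradients]
optimizer.apply_gradients(zip(grads, [variable]))  # mean gradient: [1.0, 0.5]
accumulator.reset()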
@ -6,7 +6,7 @@ import re
 import shutil
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@ -20,6 +20,7 @@ from tqdm.auto import tqdm, trange
 from .data.data_collator import DataCollator, DefaultDataCollator
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
 from .training_args import TrainingArguments


@ -87,30 +88,6 @@ def torch_distributed_zero_first(local_rank: int):
         torch.distributed.barrier()


-class EvalPrediction(NamedTuple):
-    """
-    Evaluation output (always contains labels), to be used
-    to compute metrics.
-    """
-
-    predictions: np.ndarray
-    label_ids: np.ndarray
-
-
-class PredictionOutput(NamedTuple):
-    predictions: np.ndarray
-    label_ids: Optional[np.ndarray]
-    metrics: Optional[Dict[str, float]]
-
-
-class TrainOutput(NamedTuple):
-    global_step: int
-    training_loss: float
-
-
-PREFIX_CHECKPOINT_DIR = "checkpoint"
-
-
 class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for PyTorch,
429 src/transformers/trainer_tf.py Normal file
@ -0,0 +1,429 @@
"""Tensorflow trainer class."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel, shape_list
|
||||
from .optimization_tf import GradientAccumulator, create_optimizer
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TFTrainer:
|
||||
model: TFPreTrainedModel
|
||||
args: TFTrainingArguments
|
||||
# something similar to a PT Dataset.
|
||||
# This is just temporary before to have
|
||||
# a framework-agnostic approach for datasets.
|
||||
train_dataset: Optional[tf.data.Dataset]
|
||||
eval_dataset: Optional[tf.data.Dataset]
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
||||
prediction_loss_only: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: TFPreTrainedModel,
|
||||
args: TFTrainingArguments,
|
||||
train_dataset: Optional[tf.data.Dataset] = None,
|
||||
eval_dataset: Optional[tf.data.Dataset] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
prediction_loss_only=False,
|
||||
):
|
||||
self.model = model
|
||||
self.args = args
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.compute_metrics = compute_metrics
|
||||
self.prediction_loss_only = prediction_loss_only
|
||||
self.gradient_accumulator = GradientAccumulator()
|
||||
|
||||
self._setup_training()
|
||||
|
||||
def _setup_training(self) -> None:
|
||||
"""
|
||||
Setup the different steps to train a model:
|
||||
- check if all the data are given
|
||||
- create the proper strategy
|
||||
- create the features
|
||||
- prepare the model settings
|
||||
"""
|
||||
self._prepare_dataset()
|
||||
|
||||
with self.args.strategy.scope():
|
||||
self._create_optimizer()
|
||||
_ = self.optimizer.iterations
|
||||
self._set_loss_and_metric()
|
||||
self._create_checkpoint_manager()
|
||||
self._create_summary_writer()
|
||||
|
||||
def _set_loss_and_metric(self) -> None:
|
||||
"""
|
||||
Create the training loss and metric with their name. Allowed names are those listed
|
||||
in the Tensorflow documentation and those contained in the transformers library.
|
||||
"""
|
||||
try:
|
||||
self.loss = tf.keras.losses.get(
|
||||
{
|
||||
"class_name": self.args.loss_name,
|
||||
"config": {"from_logits": True, "reduction": tf.keras.losses.Reduction.NONE},
|
||||
}
|
||||
)
|
||||
except TypeError:
|
||||
self.loss = tf.keras.losses.get(
|
||||
{"class_name": self.args.loss_name, "config": {"reduction": tf.keras.losses.Reduction.NONE}}
|
||||
)
|
||||
|
||||
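    # A hedged aside on the lookup above: tf.keras.losses.get accepts a
    # {"class_name", "config"} dict and raises TypeError when the resolved
    # class does not take one of the config keys (e.g. "from_logits" on a
    # regression loss), which is exactly what the fallback catches.
    # Illustrative lookup:
    #
    #     tf.keras.losses.get(
    #         {"class_name": "SparseCategoricalCrossentropy",
    #          "config": {"from_logits": True, "reduction": tf.keras.losses.Reduction.NONE}}
    #     )  # -> a SparseCategoricalCrossentropy instance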
    def _create_summary_writer(self) -> None:
        """
        Create a summary writer to be able to read the logs in Tensorboard.
        """
        self.writer = tf.summary.create_file_writer(self.args.logging_dir)

    def _prepare_dataset(self) -> None:
        """
        Prepare the training, validation and test data.
        """
        if self.train_dataset is not None:
            self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()

            if self.args.max_steps > 0:
                self.train_steps = self.args.max_steps
            else:
                self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)

            self.train_dataset = (
                self.train_dataset.cache()
                .shuffle(self.num_train_examples)
                .batch(self.args.train_batch_size)
                .prefetch(tf.data.experimental.AUTOTUNE)
            )

            if self.args.max_steps > 0:
                self.train_dataset = self.train_dataset.repeat(-1)

            self.train_dataset = self.args.strategy.experimental_distribute_dataset(self.train_dataset)
        else:
            self.train_steps = 0

        if self.eval_dataset is not None:
            self.eval_dataset = (
                self.eval_dataset.batch(self.args.eval_batch_size).cache().prefetch(tf.data.experimental.AUTOTUNE)
            )
            self.eval_dataset = self.args.strategy.experimental_distribute_dataset(self.eval_dataset)
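    # Note on the example count above: a tf.data pipeline does not always know
    # its own cardinality, hence the fold over the dataset. The same idiom on a
    # toy dataset (illustrative):
    #
    #     ds = tf.data.Dataset.range(10)
    #     ds.reduce(tf.constant(0), lambda x, _: x + 1).numpy()  # -> 10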
    def _create_optimizer(self) -> None:
        """
        Create the training optimizer with its name. Allowed names are those listed
        in the Tensorflow documentation and those contained in the transformers library.
        """
        if self.args.optimizer_name == "adamw":
            self.optimizer = create_optimizer(self.args.learning_rate, self.train_steps, self.args.warmup_steps)
        else:
            try:
                self.optimizer = tf.keras.optimizers.get(
                    {
                        "class_name": self.args.optimizer_name,
                        "config": {"learning_rate": self.args.learning_rate, "epsilon": self.args.adam_epsilon},
                    }
                )
            except TypeError:
                # This is for the case where the optimizer is not Adam-like, such as SGD
                self.optimizer = tf.keras.optimizers.get(
                    {"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
                )

    def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
        """
        Create a checkpoint manager in order to be able to make the training
        fault-tolerant.
        Args:
            max_to_keep: the maximum number of checkpoints to keep in the checkpoint path.
            load_model: if we want to start the training from the latest checkpoint.
        """
        ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
        self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=max_to_keep)

        if load_model:
            ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()

    @tf.function
    def _evaluate_steps(self, per_replica_features, per_replica_labels):
        """
        One-step evaluation across replicas.
        Args:
            per_replica_features: the batched features.
            per_replica_labels: the batched labels.
        Returns:
            The loss corresponding to the given batch.
        """
        per_replica_loss, per_replica_logits = self.args.strategy.experimental_run_v2(
            self._run_model, args=(per_replica_features, per_replica_labels, False)
        )

        try:
            reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
        except ValueError:
            reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)

        return reduced_loss, per_replica_logits

    def _prediction_loop(
        self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        logger.info("***** Running %s *****", description)
        logger.info("  Batch size = %d", self.args.eval_batch_size)

        label_ids: np.ndarray = None
        preds: np.ndarray = None

        step: int = 1

        for features, labels in dataset:
            step = tf.convert_to_tensor(step, dtype=tf.int64)
            loss, logits = self._evaluate_steps(features, labels)
            loss = tf.reduce_mean(loss)

            if not prediction_loss_only:
                if self.args.n_gpu > 1:
                    for val in logits.values:
                        if preds is None:
                            preds = val.numpy()
                        else:
                            preds = np.append(preds, val.numpy(), axis=0)

                    for val in labels.values:
                        if label_ids is None:
                            label_ids = val.numpy()
                        else:
                            label_ids = np.append(label_ids, val.numpy(), axis=0)
                else:
                    if preds is None:
                        preds = logits.numpy()
                    else:
                        preds = np.append(preds, logits.numpy(), axis=0)

                    if label_ids is None:
                        label_ids = labels.numpy()
                    else:
                        label_ids = np.append(label_ids, labels.numpy(), axis=0)

            step += 1

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        metrics["loss"] = loss.numpy()

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def evaluate(
        self, eval_dataset: Optional[tf.data.Dataset] = None, prediction_loss_only: Optional[bool] = None
    ) -> Dict[str, float]:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset

        output = self._prediction_loop(eval_dataset, description="Evaluation")

        return output.metrics

    def train(self) -> None:
        """
        Train the model.
        """
        if self.args.debug:
            tf.summary.trace_on(graph=True, profiler=True)

        self.gradient_accumulator.reset()

        iterations = self.optimizer.iterations

        if iterations.numpy() > 0:
            logger.info("Start the training from the last checkpoint")
            start_epoch = (iterations.numpy() // self.train_steps) + 1
        else:
            start_epoch = 1

        tf.summary.experimental.set_step(iterations)

        epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_train_examples)
        logger.info("  Num Epochs = %d", epochs)
        logger.info("  Total optimization steps = %d", self.train_steps)

        for epoch in range(start_epoch, int(epochs + 1)):
            for training_loss in self._training_steps():
                step = iterations.numpy()

                if self.args.debug:
                    with self.writer.as_default():
                        tf.summary.scalar("loss", training_loss, step=step)

                if step == 1 and self.args.debug:
                    with self.writer.as_default():
                        tf.summary.trace_export(name="training", step=step, profiler_outdir=self.args.logging_dir)

                if self.args.evaluate_during_training and step % self.args.eval_steps == 0:
                    logs = {}
                    results = self.evaluate()

                    for key, value in results.items():
                        eval_key = "eval_{}".format(key)
                        logs[eval_key] = value

                    if callable(self.optimizer.learning_rate):
                        logs["learning_rate"] = self.optimizer.learning_rate(step).numpy()
                    else:
                        logs["learning_rate"] = self.optimizer.learning_rate.numpy()

                    logger.info("Epoch {} Step {} Validation Metrics {}".format(epoch, step, logs))

                    with self.writer.as_default():
                        for k, v in logs.items():
                            tf.summary.scalar(k, v, step=step)

                if step % self.args.logging_steps == 0:
                    logger.info("Epoch {} Step {} Train Loss {:.4f}".format(epoch, step, training_loss.numpy()))

                if step % self.args.save_steps == 0:
                    ckpt_save_path = self.model.ckpt_manager.save()
                    logger.info("Saving checkpoint for step {} at {}".format(step, ckpt_save_path))

                if step % self.train_steps == 0:
                    break

    def _training_steps(self):
        """
        Returns a generator over training steps (i.e. parameter updates).
        """
        for i, loss in enumerate(self._accumulate_next_gradients()):
            if i % self.args.gradient_accumulation_steps == 0:
                self._apply_gradients()
                yield loss

    @tf.function
    def _apply_gradients(self):
        """Applies the gradients (cross-replica)."""
        self.args.strategy.experimental_run_v2(self._step)

    def _step(self):
        """Applies gradients and resets accumulation."""
        gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
        gradients = [
            gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
        ]
        gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
        vars = self.model.trainable_variables

        if self.args.mode == "token-classification":
            vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]

        self.optimizer.apply_gradients(list(zip(gradients, vars)))
        self.gradient_accumulator.reset()
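    # A worked check of the scaling in _step, with toy numbers (assuming 2
    # replicas and 4 accumulation steps): each replica holds the sum of its 4
    # local micro-batch gradients, so dividing by step * num_replicas_in_sync
    # before the cross-replica aggregation performed by apply_gradients
    # recovers the mean micro-batch gradient:
    #
    #     replica_sums = [12.0, 4.0]          # per-replica accumulated sums
    #     gradient_scale = 4 * 2              # step * num_replicas_in_sync
    #     sum(replica_sums) / gradient_scale  # -> 2.0, the mean micro-batch gradient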
    def _accumulate_next_gradients(self):
        """Accumulates the gradients from the next element in dataset."""
        iterator = iter(self.train_dataset)

        @tf.function
        def _accumulate_next():
            per_replica_features, per_replica_labels = next(iterator)

            return self._accumulate_gradients(per_replica_features, per_replica_labels)

        while True:
            try:
                yield _accumulate_next()
            except tf.errors.OutOfRangeError:
                break

    def _accumulate_gradients(self, per_replica_features, per_replica_labels):
        """Accumulates the gradients across all the replicas."""
        per_replica_loss = self.args.strategy.experimental_run_v2(
            self._forward, args=(per_replica_features, per_replica_labels)
        )

        try:
            reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
        except ValueError:
            reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)

        return reduced_loss

    def _forward(self, features, labels):
        """Forwards a training example and accumulates the gradients."""
        per_example_loss, _ = self._run_model(features, labels, True)
        vars = self.model.trainable_variables

        if self.args.mode == "token-classification":
            vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]

        gradients = self.optimizer.get_gradients(per_example_loss, vars)

        self.gradient_accumulator(gradients)

        return per_example_loss

    def _run_model(self, features, labels, training):
        """
        Computes the loss of the given features and labels pair.
        Args:
            features: the batched features.
            labels: the batched labels.
            training: run the model in training mode or not
        """
        if self.args.mode == "sequence-classification" or self.args.mode == "token-classification":
            logits = self.model(features, training=training)[0]
        else:
            logits = self.model(features, training=training)

        if self.args.mode == "token-classification":
            active_loss = tf.reshape(labels, (-1,)) != -1
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
            loss = self.loss(labels, reduced_logits)
        else:
            loss = self.loss(labels, logits)

        loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)

        return loss, logits

    def predict(self, test_dataset: tf.data.Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.
        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        Args:
            test_dataset: something similar to a PT Dataset. This is just
                temporary until we have a framework-agnostic approach for datasets.
        """
        test_dataset = test_dataset.batch(self.args.eval_batch_size)
        test_dataset = self.args.strategy.experimental_distribute_dataset(test_dataset)

        return self._prediction_loop(test_dataset, description="Prediction")
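    # Illustrative call pattern for predict() (the dataset name is
    # hypothetical): the test dataset is passed unbatched, and the returned
    # PredictionOutput carries the stacked logits, the label ids when present,
    # and the metrics dict:
    #
    #     output = trainer.predict(test_dataset)
    #     preds = np.argmax(output.predictions, axis=1)
    #     output.metrics  # includes "loss" plus any compute_metrics values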
    def save_model(self) -> None:
        """
        Save the pretrained model and create a Tensorflow saved model.
        """
        logger.info("Saving model in {}".format(self.args.output_dir))

        path = os.path.join(self.args.output_dir, "saved_model")

        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(self.args.output_dir)
27 src/transformers/trainer_utils.py Normal file
@ -0,0 +1,27 @@
from typing import Dict, NamedTuple, Optional

import numpy as np


class EvalPrediction(NamedTuple):
    """
    Evaluation output (always contains labels), to be used
    to compute metrics.
    """

    predictions: np.ndarray
    label_ids: np.ndarray


class PredictionOutput(NamedTuple):
    predictions: np.ndarray
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]


class TrainOutput(NamedTuple):
    global_step: int
    training_loss: float


PREFIX_CHECKPOINT_DIR = "checkpoint"
75 src/transformers/training_args_tf.py Normal file
@ -0,0 +1,75 @@
import logging
from dataclasses import dataclass, field
from typing import Tuple

from .file_utils import cached_property, is_tf_available, tf_required
from .training_args import TrainingArguments


logger = logging.getLogger(__name__)

if is_tf_available():
    import tensorflow as tf


@dataclass
class TFTrainingArguments(TrainingArguments):
    optimizer_name: str = field(
        default="adam",
        metadata={
            "help": 'Name of a Tensorflow optimizer among "adadelta, adagrad, adam, adamax, ftrl, nadam, rmsprop, sgd, adamw"'
        },
    )
    mode: str = field(
        default="sequence-classification",
        metadata={"help": 'Type of task, one of "sequence-classification", "token-classification" '},
    )
    loss_name: str = field(
        default="SparseCategoricalCrossentropy",
        metadata={
            "help": "Name of a Tensorflow loss. For the list see: https://www.tensorflow.org/api_docs/python/tf/keras/losses"
        },
    )
    eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
    debug: bool = field(
        default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
    )

    @cached_property
    @tf_required
    def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
        logger.info("Tensorflow: setting up strategy")
        gpus = tf.config.list_physical_devices("GPU")

        if self.no_cuda:
            strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
        else:
            try:
                tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            except ValueError:
                tpu = None

            if tpu:
                tf.config.experimental_connect_to_cluster(tpu)
                tf.tpu.experimental.initialize_tpu_system(tpu)

                strategy = tf.distribute.experimental.TPUStrategy(tpu)
            elif len(gpus) == 0:
                strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
            elif len(gpus) > 1:
                # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
                strategy = tf.distribute.MirroredStrategy(gpus)
            else:
                raise ValueError("Cannot find the proper strategy please check your environment properties.")

        return strategy

    @property
    @tf_required
    def strategy(self) -> "tf.distribute.Strategy":
        return self._setup_strategy

    @property
    @tf_required
    def n_gpu(self) -> int:
        return self._setup_strategy.num_replicas_in_sync
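A brief usage sketch (the output path is illustrative): the strategy is resolved lazily through `cached_property`, so constructing the arguments is cheap, `n_gpu` reports the replicas of the resolved strategy, and model building belongs under `strategy.scope()` as in the GLUE example above.

from transformers import TFTrainingArguments

training_args = TFTrainingArguments(output_dir="./out")  # path is illustrative
print(training_args.n_gpu)  # number of replicas in the resolved strategy

with training_args.strategy.scope():
    pass  # build or restore the model here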
@ -36,14 +36,13 @@ class OptimizationFTest(unittest.TestCase):
     def testGradientAccumulatorDistributionStrategy(self):
-        context._context = None
-        ops.enable_eager_execution_internal()
-        physical_devices = tf.config.experimental.list_physical_devices("CPU")
-        tf.config.experimental.set_virtual_device_configuration(
-            physical_devices[0],
-            [tf.config.experimental.VirtualDeviceConfiguration(), tf.config.experimental.VirtualDeviceConfiguration()],
-        )
-
-        devices = tf.config.experimental.list_logical_devices(device_type="CPU")
-        strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])
+        physical_devices = tf.config.list_physical_devices("CPU")
+        if len(physical_devices) == 1:
+            tf.config.set_logical_device_configuration(
+                physical_devices[0], [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
+            )
+        devices = tf.config.list_logical_devices(device_type="CPU")
+        strategy = tf.distribute.MirroredStrategy(devices=devices[:2])

         with strategy.scope():
             accumulator = GradientAccumulator()
@ -55,13 +54,14 @@ class OptimizationFTest(unittest.TestCase):
             accumulator([gradient])

         def apply_on_replica():
-            optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)
+            optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])))

         @tf.function
         def accumulate(grad1, grad2):
             with strategy.scope():
-                gradient_placeholder.values[0].assign(grad1)
-                gradient_placeholder.values[1].assign(grad2)
+                local_variables = strategy.experimental_local_results(gradient_placeholder)
+                local_variables[0].assign(grad1)
+                local_variables[1].assign(grad2)
                 strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))

         @tf.function
@ -69,15 +69,18 @@ class OptimizationFTest(unittest.TestCase):
             with strategy.scope():
                 strategy.experimental_run_v2(apply_on_replica)

+        def _check_local_values(grad1, grad2):
+            values = strategy.experimental_local_results(accumulator._gradients[0])
+            self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
+            self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)
+
         accumulate([1.0, 2.0], [-1.0, 1.0])
         accumulate([3.0, -1.0], [-1.0, -1.0])
         accumulate([-2.0, 2.0], [3.0, -2.0])
         self.assertEqual(accumulator.step, 3)
-        self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
-        self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
+        _check_local_values([2.0, 3.0], [1.0, -2.0])
         apply_grad()
-        self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
+        self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
         accumulator.reset()
         self.assertEqual(accumulator.step, 0)
-        self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
-        self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
+        _check_local_values([0.0, 0.0], [0.0, 0.0])