# coding=utf-8
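"""Fine-tune Transformers models (BERT, RoBERTa, DistilBERT) on a named entity
recognition task with TensorFlow 2, distributing training over CPU, GPUs or a
TPU via tf.distribute."""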

import collections
import datetime
import glob
import math
import os
import re

import numpy as np
import tensorflow as tf
from absl import app, flags, logging
from fastprogress import master_bar, progress_bar
from seqeval import metrics

from transformers import (
    TF2_WEIGHTS_NAME,
    BertConfig,
    BertTokenizer,
    DistilBertConfig,
    DistilBertTokenizer,
    GradientAccumulator,
    RobertaConfig,
    RobertaTokenizer,
    TFBertForTokenClassification,
    TFDistilBertForTokenClassification,
    TFRobertaForTokenClassification,
    create_optimizer,
)
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file

ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
)

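# Map each --model_type choice to its (config, model, tokenizer) classes.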
MODEL_CLASSES = {
    "bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
    "roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
}

flags.DEFINE_string(
    "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) for the task."
)

flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))

flags.DEFINE_string(
    "model_name_or_path",
    None,
    "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)

flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)

flags.DEFINE_string("config_name", "", "Pretrained config name or path, if not the same as model_name.")

flags.DEFINE_string("tokenizer_name", "", "Pretrained tokenizer name or path, if not the same as model_name.")

flags.DEFINE_string("cache_dir", "", "Where to store the pre-trained models downloaded from s3.")

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sentence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.",
)

flags.DEFINE_string(
    "tpu",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 URL.",
)

flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")

flags.DEFINE_boolean("do_train", False, "Whether to run training.")

flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")

flags.DEFINE_boolean(
    "evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
)

flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")

flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")

flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")

flags.DEFINE_integer(
    "gradient_accumulation_steps", 1, "Number of update steps to accumulate before performing a backward/update pass."
)

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("weight_decay", 0.0, "Weight decay, if any, to apply.")

flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for the Adam optimizer.")

flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")

flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")

flags.DEFINE_integer(
    "max_steps", -1, "If > 0: total number of training steps to perform. Overrides num_train_epochs."
)

flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")

flags.DEFINE_integer("logging_steps", 50, "Log every X update steps.")

flags.DEFINE_integer("save_steps", 50, "Save a checkpoint every X update steps.")

flags.DEFINE_boolean(
    "eval_all_checkpoints",
    False,
    "Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number.",
)

flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA even when it is available.")

flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory.")

flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets.")

flags.DEFINE_integer("seed", 42, "Random seed for initialization.")

flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit.")

flags.DEFINE_string(
    "gpus",
    "0",
    "Comma separated list of GPU devices. If only one is given, switch to a "
    "single-GPU strategy; if None, take all available GPUs.",
)
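
# Illustrative invocation (the script name, data path and model name below are placeholders):
#   python run_tf_ner.py \
#       --data_dir ./conll2003 \
#       --model_type bert \
#       --model_name_or_path bert-base-cased \
#       --output_dir ./ner-model \
#       --do_train --do_eval
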
def train(
    args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
):
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (
            math.ceil(num_train_examples / train_batch_size)
            // args["gradient_accumulation_steps"]
            * args["num_train_epochs"]
        )

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

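    # Average the accumulated gradients over (number of devices x accumulation
    # steps) before applying them; max_grad_norm is forwarded to the optimizer.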
    @tf.function
    def apply_gradients():
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()

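    # One distributed training step: each replica computes the masked token
    # classification loss and pushes its gradients into the accumulator; the
    # per-example losses are then reduced to a mean across replicas.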
    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {"attention_mask": train_features["input_mask"], "training": True}

            if args["model_type"] != "distilbert":
                inputs["token_type_ids"] = (
                    train_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
                )

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                logits = tf.reshape(logits, (-1, len(labels) + 1))
                active_loss = tf.reshape(train_features["input_mask"], (-1,))
                active_logits = tf.boolean_mask(logits, active_loss)
                train_labels = tf.reshape(train_labels, (-1,))
                active_labels = tf.boolean_mask(train_labels, active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        return mean_loss

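    # Epoch loop: weights are only updated every gradient_accumulation_steps
    # inner steps; logging and checkpointing happen on global_step boundaries.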
    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(
            train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
        )
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                        # Log metrics
                        if (
                            args["n_device"] == 1 and args["evaluate_during_training"]
                        ):  # Only evaluate on a single GPU, otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
                            )
                            report = metrics.classification_report(y_true, y_pred, digits=4)

                            logging.info("Eval at step " + str(global_step) + "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss, global_step)
                                tf.summary.scalar("precision", precision, global_step)
                                tf.summary.scalar("recall", recall, global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
                            )

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss", loss_metric.result(), step=step)

                    if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s", output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() - current_time))


def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
    eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
    eval_dataset, size = load_and_cache_examples(
        args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
    )
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    preds = None
    num_eval_steps = math.ceil(size / eval_batch_size)
    master = master_bar(range(1))
    eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    loss = 0.0

    logging.info("***** Running evaluation *****")
    logging.info("  Num examples = %d", size)
    logging.info("  Batch size = %d", eval_batch_size)

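    # Collect logits and gold label ids batch by batch; predictions are taken
    # as the argmax over the label dimension once the loop is done.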
    for eval_features, eval_labels in eval_iterator:
        inputs = {"attention_mask": eval_features["input_mask"], "training": False}

        if args["model_type"] != "distilbert":
            inputs["token_type_ids"] = (
                eval_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
            )

        with strategy.scope():
            logits = model(eval_features["input_ids"], **inputs)[0]
            tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
            active_loss = tf.reshape(eval_features["input_mask"], (-1,))
            active_logits = tf.boolean_mask(tmp_logits, active_loss)
            tmp_eval_labels = tf.reshape(eval_labels, (-1,))
            active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
            cross_entropy = loss_fct(active_labels, active_logits)
            loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)

        if preds is None:
            preds = logits.numpy()
            label_ids = eval_labels.numpy()
        else:
            preds = np.append(preds, logits.numpy(), axis=0)
            label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)

    preds = np.argmax(preds, axis=2)
    y_pred = [[] for _ in range(label_ids.shape[0])]
    y_true = [[] for _ in range(label_ids.shape[0])]
    loss = loss / num_eval_steps

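    # Rebuild the label sequences, skipping padding positions. Label ids are
    # shifted by one because id 0 is reserved for pad_token_label_id, hence the
    # "- 1" when indexing into labels.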
    for i in range(label_ids.shape[0]):
        for j in range(label_ids.shape[1]):
            if label_ids[i, j] != pad_token_label_id:
                y_pred[i].append(labels[preds[i, j] - 1])
                y_true[i].append(labels[label_ids[i, j] - 1])

    return y_true, y_pred, loss.numpy()


def load_cache(cached_file, max_seq_length):
    name_to_features = {
        "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    }

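    # Parse one serialized tf.train.Example back into a (features, label_ids) pair.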
    def _decode_record(record):
        example = tf.io.parse_single_example(record, name_to_features)
        features = {}
        features["input_ids"] = example["input_ids"]
        features["input_mask"] = example["input_mask"]
        features["segment_ids"] = example["segment_ids"]

        return features, example["label_ids"]

    d = tf.data.TFRecordDataset(cached_file)
    d = d.map(_decode_record, num_parallel_calls=4)
    count = d.reduce(0, lambda x, _: x + 1)

    return d, count.numpy()


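# Serialize the features to a TFRecord file so later runs can skip the
# example-to-feature conversion entirely.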
def save_cache(features, cached_features_file):
    writer = tf.io.TFRecordWriter(cached_features_file)

    def create_int_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    for ex_index, feature in enumerate(features):
        if ex_index % 5000 == 0:
            logging.info("Writing example %d of %d", ex_index, len(features))

        record_feature = collections.OrderedDict()
        record_feature["input_ids"] = create_int_feature(feature.input_ids)
        record_feature["input_mask"] = create_int_feature(feature.input_mask)
        record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
        record_feature["label_ids"] = create_int_feature(feature.label_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))

        writer.write(tf_example.SerializeToString())

    writer.close()


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
    drop_remainder = True if args["tpu"] or mode == "train" else False

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args["data_dir"],
        "cached_{}_{}_{}.tf_record".format(
            mode, list(filter(None, args["model_name_or_path"].split("/"))).pop(), str(args["max_seq_length"])
        ),
    )

    if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
        logging.info("Loading features from cached file %s", cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])
    else:
        logging.info("Creating features from dataset file at %s", args["data_dir"])
        examples = read_examples_from_file(args["data_dir"], mode)
        features = convert_examples_to_features(
            examples,
            labels,
            args["max_seq_length"],
            tokenizer,
            cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=bool(args["model_type"] in ["roberta"]),
            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=bool(args["model_type"] in ["xlnet"]),
            # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
            pad_token_label_id=pad_token_label_id,
        )
        logging.info("Saving features into cached file %s", cached_features_file)
        save_cache(features, cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])

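    # The training dataset repeats indefinitely and is shuffled with a fixed
    # seed for reproducibility.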
    if mode == "train":
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])

    dataset = dataset.batch(batch_size, drop_remainder)
    dataset = dataset.prefetch(buffer_size=batch_size)

    return dataset, size


def main(_):
    logging.set_verbosity(logging.INFO)
    args = flags.FLAGS.flag_values_dict()

    if (
        os.path.exists(args["output_dir"])
        and os.listdir(args["output_dir"])
        and args["do_train"]
        and not args["overwrite_output_dir"]
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args["output_dir"]
            )
        )

    if args["fp16"]:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

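    # Pick a tf.distribute strategy: TPU if one was requested, MirroredStrategy
    # across several GPUs, otherwise a single CPU or GPU device.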
    if args["tpu"]:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        args["n_device"] = args["num_tpu_cores"]
    elif len(args["gpus"].split(",")) > 1:
        args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
        strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
    elif args["no_cuda"]:
        args["n_device"] = 1
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    else:
        args["n_device"] = len(args["gpus"].split(","))
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])

    logging.warning(
        "n_device: %s, distributed training: %s, 16-bits training: %s",
        args["n_device"],
        bool(args["n_device"] > 1),
        args["fp16"],
    )

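    # Label id 0 is reserved for padding, so the model is configured with
    # len(labels) + 1 output classes.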
    labels = get_labels(args["labels"])
    num_labels = len(labels) + 1
    pad_token_label_id = 0
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    config = config_class.from_pretrained(
        args["config_name"] if args["config_name"] else args["model_name_or_path"],
        num_labels=num_labels,
        cache_dir=args["cache_dir"] if args["cache_dir"] else None,
    )

    logging.info("Training/evaluation parameters %s", args)

    # Training
    if args["do_train"]:
        tokenizer = tokenizer_class.from_pretrained(
            args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
            do_lower_case=args["do_lower_case"],
            cache_dir=args["cache_dir"] if args["cache_dir"] else None,
        )

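        # The loss in train() uses SparseCategoricalCrossentropy with its default
        # from_logits=False, so the classification head is switched to a softmax
        # activation after loading.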
        with strategy.scope():
            model = model_class.from_pretrained(
                args["model_name_or_path"],
                from_pt=bool(".bin" in args["model_name_or_path"]),
                config=config,
                cache_dir=args["cache_dir"] if args["cache_dir"] else None,
            )
            model.layers[-1].activation = tf.keras.activations.softmax

        train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
        train_dataset, num_train_examples = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
        )
        train_dataset = strategy.experimental_distribute_dataset(train_dataset)
        train(
            args,
            strategy,
            train_dataset,
            tokenizer,
            model,
            num_train_examples,
            labels,
            train_batch_size,
            pad_token_label_id,
        )

        if not os.path.exists(args["output_dir"]):
            os.makedirs(args["output_dir"])

        logging.info("Saving model to %s", args["output_dir"])

        model.save_pretrained(args["output_dir"])
        tokenizer.save_pretrained(args["output_dir"])

    # Evaluation
    if args["do_eval"]:
        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        checkpoints = []
        results = []

        if args["eval_all_checkpoints"]:
            checkpoints = list(
                os.path.dirname(c)
                for c in sorted(
                    glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
                    key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
                )
            )

        logging.info("Evaluate the following checkpoints: %s", checkpoints)

        if len(checkpoints) == 0:
            checkpoints.append(args["output_dir"])

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

            with strategy.scope():
                model = model_class.from_pretrained(checkpoint)

            y_true, y_pred, eval_loss = evaluate(
                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
            )
            report = metrics.classification_report(y_true, y_pred, digits=4)

            if global_step:
                results.append({global_step + "_report": report, global_step + "_loss": eval_loss})

        output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            for res in results:
                for key, val in res.items():
                    if "loss" in key:
                        logging.info(key + " = " + str(val))
                        writer.write(key + " = " + str(val))
                        writer.write("\n")
                    else:
                        logging.info(key)
                        logging.info("\n" + val)
                        writer.write(key + "\n")
                        writer.write(val)
                        writer.write("\n")

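    # Prediction: tag the test split and write one "token label" pair per line,
    # re-aligned with the original test.txt.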
    if args["do_predict"]:
        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        model = model_class.from_pretrained(args["output_dir"])
        eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
        predict_dataset, _ = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
        )
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
        output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)

        with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
            logging.info("\n" + report)

            writer.write(report)
            writer.write("\n\nloss = " + str(pred_loss))

        with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
            with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
                example_id = 0

                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)

                        if not y_pred[example_id]:
                            example_id += 1
                    elif y_pred[example_id]:
                        output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    flags.mark_flag_as_required("model_name_or_path")
    flags.mark_flag_as_required("model_type")
    app.run(main)