From 9200a759d782a87530765fb32f52b6248c7f4d03 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 5 Dec 2019 12:56:43 +0100 Subject: [PATCH] Add few tests on the TF optimization file with some info in the documentation. Complete the README. --- .../main_classes/optimizer_schedules.rst | 24 +++++ examples/README.md | 77 +++++++++++++++- examples/run_tf_ner.py | 7 +- transformers/tests/optimization_tf_test.py | 89 +++++++++++++++++++ 4 files changed, 191 insertions(+), 6 deletions(-) create mode 100644 transformers/tests/optimization_tf_test.py diff --git a/docs/source/main_classes/optimizer_schedules.rst b/docs/source/main_classes/optimizer_schedules.rst index b30a2e0e2e1..22ed1b28fb0 100644 --- a/docs/source/main_classes/optimizer_schedules.rst +++ b/docs/source/main_classes/optimizer_schedules.rst @@ -5,6 +5,7 @@ The ``.optimization`` module provides: - an optimizer with weight decay fixed that can be used to fine-tuned models, and - several schedules in the form of schedule objects that inherit from ``_LRSchedule``: +- a gradient accumulation class to accumulate the gradients of multiple batches ``AdamW`` ~~~~~~~~~~~~~~~~ @@ -12,6 +13,15 @@ The ``.optimization`` module provides: .. autoclass:: transformers.AdamW :members: +``AdamWeightDecay`` +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.AdamWeightDecay + :members: + +.. autofunction:: transformers.create_optimizer + :members: + Schedules ---------------------------------------------------- @@ -49,3 +59,17 @@ Learning Rate Schedules .. image:: /imgs/warmup_linear_schedule.png :target: /imgs/warmup_linear_schedule.png :alt: + +``Warmup`` +~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Warmup + :members: + +Gradient Strategies +---------------------------------------------------- + +``GradientAccumulator`` +~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.GradientAccumulator diff --git a/examples/README.md b/examples/README.md index 960b218f112..2dd66539162 100644 --- a/examples/README.md +++ b/examples/README.md @@ -465,7 +465,8 @@ Training with the previously defined hyper-parameters yields the following resul ## Named Entity Recognition -Based on the script [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py). +Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py) for Pytorch and +[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_tf_ner.py) for Tensorflow 2. This example fine-tune Bert Multilingual on GermEval 2014 (German NER). Details and results for the fine-tuning provided by @stefan-it. @@ -510,7 +511,7 @@ The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt ``` -### Training +### Prepare the run Additional environment variables must be set: @@ -522,6 +523,8 @@ export SAVE_STEPS=750 export SEED=1 ``` +### Run the Pytorch version + To start training, just run: ```bash @@ -542,7 +545,7 @@ python3 run_ner.py --data_dir ./ \ If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets. 
-### Evaluation +#### Evaluation Evaluation on development dataset outputs the following for our example: @@ -564,7 +567,7 @@ On the test dataset the following results could be achieved: 10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085 ``` -### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) +#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run): @@ -574,6 +577,72 @@ Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) a | `roberta-large` | 95.96 | 91.87 | `distilbert-base-uncased` | 94.34 | 90.32 +### Run the Tensorflow 2 version + +To start training, just run: + +```bash +python3 run_tf_ner.py --data_dir ./ \ +--model_type bert \ +--labels ./labels.txt \ +--model_name_or_path $BERT_MODEL \ +--output_dir $OUTPUT_DIR \ +--max_seq_length $MAX_LENGTH \ +--num_train_epochs $NUM_EPOCHS \ +--per_device_train_batch_size $BATCH_SIZE \ +--save_steps $SAVE_STEPS \ +--seed $SEED \ +--do_train \ +--do_eval \ +--do_predict +``` + +As with the Pytorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets. 
+ +#### Evaluation + +Evaluation on development dataset outputs the following for our example: +```bash + precision recall f1-score support + + LOCderiv 0.7619 0.6154 0.6809 52 + PERpart 0.8724 0.8997 0.8858 4057 + OTHpart 0.9360 0.9466 0.9413 711 + ORGpart 0.7015 0.6989 0.7002 269 + LOCpart 0.7668 0.8488 0.8057 496 + LOC 0.8745 0.9191 0.8963 235 + ORGderiv 0.7723 0.8571 0.8125 91 + OTHderiv 0.4800 0.6667 0.5581 18 + OTH 0.5789 0.6875 0.6286 16 + PERderiv 0.5385 0.3889 0.4516 18 + PER 0.5000 0.5000 0.5000 2 + ORG 0.0000 0.0000 0.0000 3 + +micro avg 0.8574 0.8862 0.8715 5968 +macro avg 0.8575 0.8862 0.8713 5968 +``` + +On the test dataset the following results could be achieved: +```bash + precision recall f1-score support + + PERpart 0.8847 0.8944 0.8896 9397 + OTHpart 0.9376 0.9353 0.9365 1639 + ORGpart 0.7307 0.7044 0.7173 697 + LOC 0.9133 0.9394 0.9262 561 + LOCpart 0.8058 0.8157 0.8107 1150 + ORG 0.0000 0.0000 0.0000 8 + OTHderiv 0.5882 0.4762 0.5263 42 + PERderiv 0.6571 0.5227 0.5823 44 + OTH 0.4906 0.6667 0.5652 39 + ORGderiv 0.7016 0.7791 0.7383 172 + LOCderiv 0.8256 0.6514 0.7282 109 + PER 0.0000 0.0000 0.0000 11 + +micro avg 0.8722 0.8774 0.8748 13869 +macro avg 0.8712 0.8774 0.8740 13869 +``` + ## Abstractive summarization Based on the script diff --git a/examples/run_tf_ner.py b/examples/run_tf_ner.py index ef1fcf6aa48..eb284f4c2a7 100644 --- a/examples/run_tf_ner.py +++ b/examples/run_tf_ner.py @@ -540,6 +540,9 @@ def main(_): checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1))) logging.info("Evaluate the following checkpoints: %s", checkpoints) + + if len(checkpoints) == 0: + checkpoints.append(args['output_dir']) for checkpoint in checkpoints: global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final" @@ -572,10 +575,10 @@ def main(_): if args['do_predict']: tokenizer = 
tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case']) model = model_class.from_pretrained(args['output_dir']) - eval_batch_size = args['per_gpu_eval_batch_size'] * args['n_device'] + eval_batch_size = args['per_device_eval_batch_size'] * args['n_device'] predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test") y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test") - output_test_results_file = os.path.join(args.output_dir, "test_results.txt") + output_test_results_file = os.path.join(args['output_dir'], "test_results.txt") output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt") report = metrics.classification_report(y_true, y_pred, digits=4) diff --git a/transformers/tests/optimization_tf_test.py b/transformers/tests/optimization_tf_test.py new file mode 100644 index 00000000000..ac5109cb560 --- /dev/null +++ b/transformers/tests/optimization_tf_test.py @@ -0,0 +1,89 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import pytest + +from transformers import is_tf_available + +if is_tf_available(): + import tensorflow as tf + from tensorflow.python.eager import context + from tensorflow.python.framework import ops + from transformers import (create_optimizer, GradientAccumulator) +else: + pytestmark = pytest.mark.skip("Require TensorFlow") + +class OptimizationFTest(unittest.TestCase): + def assertListAlmostEqual(self, list1, list2, tol): + self.assertEqual(len(list1), len(list2)) + for a, b in zip(list1, list2): + self.assertAlmostEqual(a, b, delta=tol) + + def testGradientAccumulator(self): + accumulator = GradientAccumulator() + accumulator([tf.constant([1.0, 2.0])]) + accumulator([tf.constant([-2.0, 1.0])]) + accumulator([tf.constant([-1.0, 2.0])]) + with self.assertRaises(ValueError): + 
accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])]) + self.assertEqual(accumulator.step, 3) + self.assertEqual(len(accumulator.gradients), 1) + self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2) + accumulator.reset() + self.assertEqual(accumulator.step, 0) + self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2) + + def testGradientAccumulatorDistributionStrategy(self): + context._context = None + ops.enable_eager_execution_internal() + physical_devices = tf.config.experimental.list_physical_devices("CPU") + tf.config.experimental.set_virtual_device_configuration( + physical_devices[0], + [tf.config.experimental.VirtualDeviceConfiguration(), + tf.config.experimental.VirtualDeviceConfiguration()]) + + devices = tf.config.experimental.list_logical_devices(device_type="CPU") + strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices]) + + with strategy.scope(): + accumulator = GradientAccumulator() + variable = tf.Variable([4.0, 3.0]) + optimizer = create_optimizer(5e-5, 10, 5) + gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False) + + def accumulate_on_replica(gradient): + accumulator([gradient]) + + def apply_on_replica(): + optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0) + + @tf.function + def accumulate(grad1, grad2): + with strategy.scope(): + gradient_placeholder.values[0].assign(grad1) + gradient_placeholder.values[1].assign(grad2) + strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,)) + + @tf.function + def apply_grad(): + with strategy.scope(): + strategy.experimental_run_v2(apply_on_replica) + + accumulate([1.0, 2.0], [-1.0, 1.0]) + accumulate([3.0, -1.0], [-1.0, -1.0]) + accumulate([-2.0, 2.0], [3.0, -2.0]) + self.assertEqual(accumulator.step, 3) + self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2) + 
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2) + apply_grad() + self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2) + accumulator.reset() + self.assertEqual(accumulator.step, 0) + self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2) + self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file