Add a few tests for the TF optimization file, with some info in the documentation. Complete the README.
This commit is contained in: parent ff98b041da, commit 9200a759d7
@ -5,6 +5,7 @@ The ``.optimization`` module provides:

- an optimizer with weight decay fixed that can be used to fine-tune models, and
- several schedules in the form of schedule objects that inherit from ``_LRSchedule``:
- a gradient accumulation class to accumulate the gradients of multiple batches

``AdamW``
~~~~~~~~~~~~~~~~

@ -12,6 +13,15 @@ The ``.optimization`` module provides:

.. autoclass:: transformers.AdamW
    :members:

``AdamWeightDecay``
~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AdamWeightDecay
    :members:

.. autofunction:: transformers.create_optimizer

Schedules
----------------------------------------------------

@ -49,3 +59,17 @@ Learning Rate Schedules

.. image:: /imgs/warmup_linear_schedule.png
    :target: /imgs/warmup_linear_schedule.png
    :alt:

``Warmup``
~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Warmup
    :members:

Gradient Strategies
----------------------------------------------------

``GradientAccumulator``
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GradientAccumulator
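To make the new documentation entries concrete, here is a short usage sketch. It is not part of this commit and sticks to the API exercised by the new test file further down: `GradientAccumulator` is called with a list of gradients, exposes `step` and `gradients`, and is cleared with `reset()`; `create_optimizer` is called positionally as in the test (learning rate, total steps, warmup steps).

```python
# Usage sketch (not part of this commit); values mirror the new test below.
import tensorflow as tf
from transformers import GradientAccumulator, create_optimizer

accumulator = GradientAccumulator()
accumulator([tf.constant([1.0, 2.0])])       # one "batch" worth of gradients
accumulator([tf.constant([-2.0, 1.0])])      # summed into the same slot
print(int(accumulator.step))                 # 2
print(accumulator.gradients[0].numpy())      # [-1.  3.]
accumulator.reset()                          # step back to 0, gradients zeroed

# Positional arguments as in the test: learning rate, total steps, warmup steps.
optimizer = create_optimizer(5e-5, 10, 5)
```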
@ -465,7 +465,8 @@ Training with the previously defined hyper-parameters yields the following resul

## Named Entity Recognition

Based on the script [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py).
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py) for PyTorch and
[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_tf_ner.py) for TensorFlow 2.
This example fine-tunes BERT Multilingual on GermEval 2014 (German NER).
Details and results for the fine-tuning are provided by @stefan-it.

@ -510,7 +511,7 @@ The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so

cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$" | sort | uniq > labels.txt
```
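If you prefer Python to the shell pipeline, a rough equivalent is sketched below. It is not part of the example scripts and assumes the same space-separated `token label` format and the file names used above.

```python
# Rough Python equivalent of the shell pipeline above: collect the unique
# labels (second whitespace-separated column) from the three data files.
labels = set()
for filename in ["train.txt", "dev.txt", "test.txt"]:
    with open(filename, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) > 1:  # skip blank lines, mirroring `grep -v "^$"`
                labels.add(parts[1])

with open("labels.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(sorted(labels)) + "\n")
```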
### Training
### Prepare the run

Additional environment variables must be set:

@ -522,6 +523,8 @@ export SAVE_STEPS=750
export SEED=1
```
### Run the PyTorch version

To start training, just run:

```bash
@ -542,7 +545,7 @@ python3 run_ner.py --data_dir ./ \

If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets.

### Evaluation
#### Evaluation

Evaluation on the development dataset outputs the following for our example:

@ -564,7 +567,7 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```

### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)

Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):

@ -574,6 +577,72 @@ Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) a
| `roberta-large` | 95.96 | 91.87
| `distilbert-base-uncased` | 94.34 | 90.32
### Run the TensorFlow 2 version

To start training, just run:

```bash
python3 run_tf_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
--max_seq_length $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--per_device_train_batch_size $BATCH_SIZE \
--save_steps $SAVE_STEPS \
--seed $SEED \
--do_train \
--do_eval \
--do_predict
```

As with the PyTorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets.

#### Evaluation

Evaluation on the development dataset outputs the following for our example:
```bash
             precision    recall  f1-score   support

  LOCderiv     0.7619    0.6154    0.6809        52
   PERpart     0.8724    0.8997    0.8858      4057
   OTHpart     0.9360    0.9466    0.9413       711
   ORGpart     0.7015    0.6989    0.7002       269
   LOCpart     0.7668    0.8488    0.8057       496
       LOC     0.8745    0.9191    0.8963       235
  ORGderiv     0.7723    0.8571    0.8125        91
  OTHderiv     0.4800    0.6667    0.5581        18
       OTH     0.5789    0.6875    0.6286        16
  PERderiv     0.5385    0.3889    0.4516        18
       PER     0.5000    0.5000    0.5000         2
       ORG     0.0000    0.0000    0.0000         3

 micro avg     0.8574    0.8862    0.8715      5968
 macro avg     0.8575    0.8862    0.8713      5968
```

On the test dataset the following results could be achieved:
```bash
             precision    recall  f1-score   support

   PERpart     0.8847    0.8944    0.8896      9397
   OTHpart     0.9376    0.9353    0.9365      1639
   ORGpart     0.7307    0.7044    0.7173       697
       LOC     0.9133    0.9394    0.9262       561
   LOCpart     0.8058    0.8157    0.8107      1150
       ORG     0.0000    0.0000    0.0000         8
  OTHderiv     0.5882    0.4762    0.5263        42
  PERderiv     0.6571    0.5227    0.5823        44
       OTH     0.4906    0.6667    0.5652        39
  ORGderiv     0.7016    0.7791    0.7383       172
  LOCderiv     0.8256    0.6514    0.7282       109
       PER     0.0000    0.0000    0.0000        11

 micro avg     0.8722    0.8774    0.8748     13869
 macro avg     0.8712    0.8774    0.8740     13869
```

## Abstractive summarization

Based on the script
@ -540,6 +540,9 @@ def main(_):
    checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))

    logging.info("Evaluate the following checkpoints: %s", checkpoints)

    if len(checkpoints) == 0:
        checkpoints.append(args['output_dir'])

    for checkpoint in checkpoints:
        global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
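For context, the `key` lambda in the hunk above sorts checkpoint paths by the integer formed from every digit that appears in the path (falling back to -1 when there are none), so `checkpoint-750` sorts before `checkpoint-1500`. A stand-alone illustration with made-up paths:

```python
# Illustration of the numeric sort key used above; the paths are made up.
paths = ["out/checkpoint-1500/tf_model.h5", "out/checkpoint-750/tf_model.h5", "out/tf_model.h5"]
key = lambda f: int(''.join(filter(str.isdigit, f)) or -1)
print(sorted(paths, key=key))
# ['out/tf_model.h5', 'out/checkpoint-750/tf_model.h5', 'out/checkpoint-1500/tf_model.h5']
```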
@ -572,10 +575,10 @@ def main(_):
    if args['do_predict']:
        tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
        model = model_class.from_pretrained(args['output_dir'])
        eval_batch_size = args['per_gpu_eval_batch_size'] * args['n_device']
        eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
        predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args.output_dir, "test_results.txt")
        output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
        output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)
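This hunk fixes two small issues: the batch-size key becomes `per_device_eval_batch_size`, apparently to match the `--per_device_train_batch_size` flag used in the README above, and `args.output_dir` becomes `args['output_dir']`, since `args` is indexed dict-style everywhere else in this script. A tiny illustration (hypothetical value) of why attribute access fails there:

```python
# Illustration only: `args` in run_tf_ner.py is accessed with ['...'] keys,
# so dict-style indexing works while attribute access on a plain dict raises.
args = {"output_dir": "germeval-model"}      # hypothetical value
print(args["output_dir"])                    # -> germeval-model
try:
    print(args.output_dir)                   # AttributeError on a plain dict
except AttributeError as err:
    print(err)
```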
transformers/tests/optimization_tf_test.py (new file, 89 lines)
@ -0,0 +1,89 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import pytest

from transformers import is_tf_available

if is_tf_available():
    import tensorflow as tf
    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops
    from transformers import (create_optimizer, GradientAccumulator)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")


class OptimizationFTest(unittest.TestCase):
    def assertListAlmostEqual(self, list1, list2, tol):
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol)

    def testGradientAccumulator(self):
        # Three accumulation steps on a single gradient slot.
        accumulator = GradientAccumulator()
        accumulator([tf.constant([1.0, 2.0])])
        accumulator([tf.constant([-2.0, 1.0])])
        accumulator([tf.constant([-1.0, 2.0])])
        with self.assertRaises(ValueError):
            # A call with a different number of gradients must be rejected.
            accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
        self.assertEqual(accumulator.step, 3)
        self.assertEqual(len(accumulator.gradients), 1)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)

    def testGradientAccumulatorDistributionStrategy(self):
        # Reset the eager context and split the CPU into two virtual devices
        # so that MirroredStrategy runs with two replicas.
        context._context = None
        ops.enable_eager_execution_internal()
        physical_devices = tf.config.experimental.list_physical_devices("CPU")
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[0],
            [tf.config.experimental.VirtualDeviceConfiguration(),
             tf.config.experimental.VirtualDeviceConfiguration()])

        devices = tf.config.experimental.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])

        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            optimizer = create_optimizer(5e-5, 10, 5)
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

        def accumulate_on_replica(gradient):
            accumulator([gradient])

        def apply_on_replica():
            optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)

        @tf.function
        def accumulate(grad1, grad2):
            with strategy.scope():
                gradient_placeholder.values[0].assign(grad1)
                gradient_placeholder.values[1].assign(grad2)
                strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))

        @tf.function
        def apply_grad():
            with strategy.scope():
                strategy.experimental_run_v2(apply_on_replica)

        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
        self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
        apply_grad()
        self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
        self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)


if __name__ == "__main__":
    unittest.main()
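Since the file ends with `unittest.main()` it can be executed directly; to run only these tests through pytest instead, an invocation along these lines should work from the repository root (the exact command is an assumption, not part of the commit):

```python
# Hypothetical way to run just the new TF optimization tests via pytest.
import sys
import pytest

sys.exit(pytest.main(["-v", "transformers/tests/optimization_tf_test.py"]))
```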