transformers/tests/test_optimization_tf.py
Julien Plu f9414f7553
Tensorflow improvements (#4530)
* Better None gradients handling

* Apply Style

* Apply Style

* Create a loss class per task to compute its respective loss

* Add loss classes to the ALBERT TF models

* Add loss classes to the BERT TF models

* Add question answering and multiple choice to TF Camembert

* Remove prints

* Add multiple choice model to TF DistilBERT + loss computation

* Add question answering model to TF Electra + loss computation

* Add token classification, question answering and multiple choice models to TF Flaubert

* Add multiple choice model to TF Roberta + loss computation

* Add multiple choice model to TF XLM + loss computation

* Add multiple choice and question answering models to TF XLM-Roberta

* Add multiple choice model to TF XLNet + loss computation

* Remove unused parameters

* Add task loss classes

* Reorder TF imports + add new model classes

* Add new model classes

* Bugfix in TF T5 model

* Bugfix for TF T5 tests

* Bugfix in TF T5 model

* Fix TF T5 model tests

* Fix T5 tests + some renaming

* Fix inheritance issue in the AutoX tests

* Add tests for TF Flaubert and TF XLM Roberta

* Add tests for TF Flaubert and TF XLM Roberta

* Remove unused piece of code in the TF trainer

* bugfix and remove unused code

* Bugfix for TF 2.2

* Apply Style

* Divide TFSequenceClassificationAndMultipleChoiceLoss into two classes with their respective names

* Apply style

* Mirror the PT Trainer in the TF one: fp16, optimizers and tb_writer as class parameter and better dataset handling

* Fix TF optimizations tests and apply style

* Remove useless parameter

* Bugfix and apply style

* Fix TF Trainer prediction

* Now the TF models return the loss such as their PyTorch counterparts

* Apply Style

* Ignore some tests output

* Take into account the SQuAD cls_index, p_mask and is_impossible parameters for the QuestionAnswering task models.

* Fix names for SQuAD data

* Apply Style

* Fix conflicts with 2.11 release

* Fix conflicts with 2.11

* Fix wrong name

* Add better documentation on the new create_optimizer function

* Fix isort

* logging_dir: use same default as PyTorch

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
2020-06-04 19:45:53 -04:00

87 lines
3.5 KiB
Python

import unittest

from transformers import is_tf_available

from .utils import require_tf


# Only pull in the TensorFlow-dependent names when TensorFlow is actually installed.
if is_tf_available():
    import tensorflow as tf
    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops
    from transformers import create_optimizer, GradientAccumulator


@require_tf
class OptimizationFTest(unittest.TestCase):
    """Tests for the TensorFlow optimization helpers ``GradientAccumulator`` and ``create_optimizer``."""

    def assertListAlmostEqual(self, list1, list2, tol):
        """Assert that two sequences have equal length and agree element-wise within ``tol``.

        Args:
            list1: First sequence of numeric values.
            list2: Second sequence of numeric values.
            tol: Maximum absolute difference tolerated for each pair of elements.
        """
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol)

    def testGradientAccumulator(self):
        """Check accumulation, step counting, input validation and reset on a single device."""
        accumulator = GradientAccumulator()
        accumulator([tf.constant([1.0, 2.0])])
        accumulator([tf.constant([-2.0, 1.0])])
        accumulator([tf.constant([-1.0, 2.0])])

        # Passing two gradients after the accumulator was fed one at a time must be rejected ...
        with self.assertRaises(ValueError):
            accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])

        # ... and the rejected call must not be counted as a step.
        self.assertEqual(accumulator.step, 3)
        self.assertEqual(len(accumulator.gradients), 1)
        # Element-wise sum of the three accepted gradients.
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)

        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)

    def testGradientAccumulatorDistributionStrategy(self):
        """Check that gradients are accumulated per replica under a two-device ``MirroredStrategy``."""
        # NOTE(review): resetting the private eager context looks like a way to allow the logical device
        # configuration below even if TensorFlow was already initialised by an earlier test - confirm.
        context._context = None
        ops.enable_eager_execution_internal()

        # Split a lone physical CPU into two logical devices so that two replicas can be mirrored on it.
        physical_devices = tf.config.list_physical_devices("CPU")
        if len(physical_devices) == 1:
            tf.config.set_logical_device_configuration(
                physical_devices[0], [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
            )
        devices = tf.config.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=devices[:2])

        # `Strategy.run` superseded `Strategy.experimental_run_v2` in TF 2.2; fall back for older releases.
        run_on_replicas = getattr(strategy, "run", None) or strategy.experimental_run_v2

        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            # Presumably (init_lr, num_train_steps, num_warmup_steps); only the optimizer is needed here.
            optimizer, _ = create_optimizer(5e-5, 10, 5)
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

        def accumulate_on_replica(gradient):
            accumulator([gradient])

        def apply_on_replica():
            optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])))

        @tf.function
        def accumulate(grad1, grad2):
            with strategy.scope():
                # Give each replica its own gradient by writing into its local copy of the placeholder.
                local_variables = strategy.experimental_local_results(gradient_placeholder)
                local_variables[0].assign(grad1)
                local_variables[1].assign(grad2)
                run_on_replicas(accumulate_on_replica, args=(gradient_placeholder,))

        @tf.function
        def apply_grad():
            with strategy.scope():
                run_on_replicas(apply_on_replica)

        def _check_local_values(grad1, grad2):
            # Compare the accumulated gradient held by each of the two replicas with its expected value.
            values = strategy.experimental_local_results(accumulator._gradients[0])
            self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
            self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)

        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        # Replica 0 saw [1, 2] + [3, -1] + [-2, 2]; replica 1 saw [-1, 1] + [-1, -1] + [3, -2].
        _check_local_values([2.0, 3.0], [1.0, -2.0])

        apply_grad()
        # The variable is expected to stay within `tol` of its initial value after the update.
        self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)

        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        _check_local_values([0.0, 0.0], [0.0, 0.0])