mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Fix tf2.4 (#9120)
* Fix tests for TF 2.4 * Remove <2.4 limitation * Add version condition * Update tests/test_optimization_tf.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_optimization_tf.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_optimization_tf.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
6ccea0486f
commit
ef2d4cd445
4
setup.py
4
setup.py
@ -127,8 +127,8 @@ _deps = [
|
|||||||
"sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
|
"sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
|
||||||
"sphinx==3.2.1",
|
"sphinx==3.2.1",
|
||||||
"starlette",
|
"starlette",
|
||||||
"tensorflow-cpu>=2.0,<2.4",
|
"tensorflow-cpu>=2.0",
|
||||||
"tensorflow>=2.0,<2.4",
|
"tensorflow>=2.0",
|
||||||
"timeout-decorator",
|
"timeout-decorator",
|
||||||
"tokenizers==0.9.4",
|
"tokenizers==0.9.4",
|
||||||
"torch>=1.0",
|
"torch>=1.0",
|
||||||
|
@ -434,14 +434,14 @@ class TFModelTesterMixin:
|
|||||||
num_labels = 2
|
num_labels = 2
|
||||||
|
|
||||||
X = tf.data.Dataset.from_tensor_slices(
|
X = tf.data.Dataset.from_tensor_slices(
|
||||||
(inputs_dict, np.random.randint(0, num_labels, (self.model_tester.batch_size, 1)))
|
(inputs_dict, np.ones((self.model_tester.batch_size, self.model_tester.seq_length, num_labels, 1)))
|
||||||
).batch(1)
|
).batch(1)
|
||||||
|
|
||||||
hidden_states = main_layer(symbolic_inputs)[0]
|
hidden_states = main_layer(symbolic_inputs)[0]
|
||||||
outputs = tf.keras.layers.Dense(num_labels, activation="softmax", name="outputs")(hidden_states)
|
outputs = tf.keras.layers.Dense(num_labels, activation="softmax", name="outputs")(hidden_states)
|
||||||
model = tf.keras.models.Model(inputs=symbolic_inputs, outputs=[outputs])
|
model = tf.keras.models.Model(inputs=symbolic_inputs, outputs=[outputs])
|
||||||
|
|
||||||
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])
|
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"])
|
||||||
model.fit(X, epochs=1)
|
model.fit(X, epochs=1)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.testing_utils import require_tf
|
from transformers.testing_utils import require_tf
|
||||||
|
|
||||||
@ -76,11 +78,17 @@ class OptimizationFTest(unittest.TestCase):
|
|||||||
local_variables = strategy.experimental_local_results(gradient_placeholder)
|
local_variables = strategy.experimental_local_results(gradient_placeholder)
|
||||||
local_variables[0].assign(grad1)
|
local_variables[0].assign(grad1)
|
||||||
local_variables[1].assign(grad2)
|
local_variables[1].assign(grad2)
|
||||||
|
if version.parse(tf.version.VERSION) >= version.parse("2.2"):
|
||||||
|
strategy.run(accumulate_on_replica, args=(gradient_placeholder,))
|
||||||
|
else:
|
||||||
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
|
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def apply_grad():
|
def apply_grad():
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
|
if version.parse(tf.version.VERSION) >= version.parse("2.2"):
|
||||||
|
strategy.run(apply_on_replica)
|
||||||
|
else:
|
||||||
strategy.experimental_run_v2(apply_on_replica)
|
strategy.experimental_run_v2(apply_on_replica)
|
||||||
|
|
||||||
def _check_local_values(grad1, grad2):
|
def _check_local_values(grad1, grad2):
|
||||||
|
Loading…
Reference in New Issue
Block a user