TF: XLA-trainable DeBERTa v2 (#18546)

* fix deberta issues

* add different code paths for gpu and tpu

* shorter gpu take along axis

* Stable Dropout without tf cond

* variable must be float
Joao Gante 2022-08-10 12:57:21 +01:00 committed by GitHub
parent 4a51075a96
commit 34aad0dac0
2 changed files with 62 additions and 54 deletions
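
As a rough illustration of what "XLA-trainable" means here (this sketch is not part of the commit; the checkpoint name is only an example and may need from_pt=True to convert PyTorch weights), a training-mode forward pass can now be compiled with jit_compile=True:

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Illustrative checkpoint; any DeBERTa / DeBERTa-v2 model should work after this commit.
model = TFAutoModelForSequenceClassification.from_pretrained("microsoft/deberta-v3-small", from_pt=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
batch = tokenizer(["XLA-friendly DeBERTa"], return_tensors="tf")

@tf.function(jit_compile=True)
def xla_forward(input_ids, attention_mask):
    # training=True exercises the rewritten StableDropout path (no tf.cond)
    return model(input_ids=input_ids, attention_mask=attention_mask, training=True).logits

print(xla_forward(batch["input_ids"], batch["attention_mask"]))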

src/transformers/models/deberta/modeling_tf_deberta.py

@@ -101,27 +101,6 @@ class TFDebertaXSoftmax(tf.keras.layers.Layer):
         return output
 
 
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-def TFDebertaXDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 class TFDebertaStableDropout(tf.keras.layers.Layer):
     """
     Optimized dropout module for stabilizing the training
@@ -132,11 +111,33 @@ class TFDebertaStableDropout(tf.keras.layers.Layer):
     def __init__(self, drop_prob, **kwargs):
         super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaXDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
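
The bullet "Stable Dropout without tf cond" refers to the xdropout rewrite above: data-dependent tf.cond branches are replaced by Python-level checks on drop_prob, which XLA can trace, while the custom gradient keeps the inverted-dropout scaling. A minimal standalone sketch of the same idea (function-level rather than a Keras layer; names are illustrative, not from the diff):

import tensorflow as tf

drop_prob = 0.1  # a plain Python float, so `drop_prob > 0` is resolved at trace time

@tf.custom_gradient
def xdropout(inputs):
    # Sample a boolean mask: True where an element should be dropped
    mask = tf.cast(
        1 - tf.compat.v1.distributions.Bernoulli(probs=1.0 - drop_prob).sample(sample_shape=tf.shape(inputs)),
        tf.bool,
    )
    scale = tf.convert_to_tensor(1.0 / (1 - drop_prob), dtype=tf.float32)
    outputs = tf.where(mask, 0.0, inputs) * scale if drop_prob > 0 else inputs

    def grad(upstream):
        # Gradient mirrors the forward pass: zero for dropped elements, scaled otherwise
        return tf.where(mask, 0.0, upstream) * scale if drop_prob > 0 else upstream

    return outputs, grad

x = tf.ones((2, 3))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = xdropout(x)
print(tape.gradient(y, x))  # 0.0 at dropped positions, 1/(1 - drop_prob) elsewhere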

src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py

@@ -102,29 +102,6 @@ class TFDebertaV2XSoftmax(tf.keras.layers.Layer):
         return output
 
 
-# Copied from transformers.models.deberta.modeling_tf_deberta.get_mask
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXDropout
-def TFDebertaV2XDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
 class TFDebertaV2StableDropout(tf.keras.layers.Layer):
     """
@@ -136,11 +113,33 @@ class TFDebertaV2StableDropout(tf.keras.layers.Layer):
     def __init__(self, drop_prob, **kwargs):
         super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaV2XDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
@@ -525,10 +524,18 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
 def take_along_axis(x, indices):
     # Only a valid port of np.take_along_axis when the gather axis is -1
-    flat_x = tf.reshape(x, (-1, x.shape[-1]))
-    flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
-    gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
-    gathered = tf.reshape(gathered, indices.shape)
+    # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
+    if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
+        # [B, S, P] -> [B, S, P, D]
+        one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
+
+        # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
+        # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
+        gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)
+
+    # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls
+    else:
+        gathered = tf.gather(x, indices, batch_dims=2)
 
     return gathered
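
The einsum comment above can be sanity-checked directly: for a gather along the last axis, multiplying a one-hot tensor against x reproduces tf.gather with batch_dims=2. A small standalone check (shapes are made up for illustration):

import tensorflow as tf

B, S, P, D = 2, 4, 3, 5
x = tf.random.normal((B, S, D))
indices = tf.random.uniform((B, S, P), maxval=D, dtype=tf.int32)

# TPU-friendly path: one-hot + einsum ("matrix times vector" over the last two dims)
one_hot_indices = tf.one_hot(indices, depth=D, dtype=x.dtype)     # [B, S, P, D]
einsum_gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)  # [B, S, P]

# GPU-friendly path: batched gather along the last axis
gather_gathered = tf.gather(x, indices, batch_dims=2)             # [B, S, P]

tf.debugging.assert_near(einsum_gathered, gather_gathered)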