Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
TF: XLA-trainable DeBERTa v2 (#18546)
* fix deberta issues
* add different code paths for gpu and tpu
* shorter gpu take along axis
* Stable Dropout without tf cond
* variable must be float
This commit is contained in:
parent 4a51075a96
commit 34aad0dac0
@@ -101,27 +101,6 @@ class TFDebertaXSoftmax(tf.keras.layers.Layer):
         return output
 
 
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-def TFDebertaXDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 class TFDebertaStableDropout(tf.keras.layers.Layer):
     """
     Optimized dropout module for stabilizing the training
@@ -132,11 +111,33 @@ class TFDebertaStableDropout(tf.keras.layers.Layer):
 
     def __init__(self, drop_prob, **kwargs):
         super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaXDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
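The rewrite above is what makes the layer XLA-friendly: `drop_prob` is now kept as a plain Python float ("variable must be float" in the commit message), so `if self.drop_prob > 0:` is resolved at trace time and no `tf.cond` is emitted into the graph, while `@tf.custom_gradient` makes the backward pass zero and rescale the same positions as the forward pass. Below is a minimal sketch, not part of this commit, of exercising the layer under XLA compilation; it assumes `TFDebertaStableDropout` is importable from `transformers.models.deberta.modeling_tf_deberta` and uses arbitrary shapes.

import tensorflow as tf

from transformers.models.deberta.modeling_tf_deberta import TFDebertaStableDropout

layer = TFDebertaStableDropout(drop_prob=0.1)


@tf.function(jit_compile=True)  # XLA compilation, the training setting this PR targets
def train_step(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = layer(x, training=True)  # dispatches to self.xdropout(inputs)
    # the custom gradient masks the same positions as the forward pass and rescales by 1/(1 - p)
    return y, tape.gradient(y, x)


y, dx = train_step(tf.random.normal((2, 4, 8)))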
@@ -102,29 +102,6 @@ class TFDebertaV2XSoftmax(tf.keras.layers.Layer):
         return output
 
 
-# Copied from transformers.models.deberta.modeling_tf_deberta.get_mask
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXDropout
-def TFDebertaV2XDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
 class TFDebertaV2StableDropout(tf.keras.layers.Layer):
     """
@@ -136,11 +113,33 @@ class TFDebertaV2StableDropout(tf.keras.layers.Layer):
 
     def __init__(self, drop_prob, **kwargs):
         super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaV2XDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
@@ -525,10 +524,18 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
 def take_along_axis(x, indices):
     # Only a valid port of np.take_along_axis when the gather axis is -1
 
-    flat_x = tf.reshape(x, (-1, x.shape[-1]))
-    flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
-    gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
-    gathered = tf.reshape(gathered, indices.shape)
+    # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
+    if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
+        # [B, S, P] -> [B, S, P, D]
+        one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
+
+        # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
+        # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
+        gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)
+
+    # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls
+    else:
+        gathered = tf.gather(x, indices, batch_dims=2)
 
     return gathered
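The two branches compute the same result and only differ in which primitive each backend handles well: one-hot plus einsum avoids the gather/reshape pattern that performs poorly on TPU, while GPUs prefer the batched gather over a large one-hot matmul. A small self-contained check, not part of the commit and using arbitrary shapes, that the two paths agree:

import tensorflow as tf

B, S, D, P = 2, 3, 5, 4  # batch, sequence, hidden, positions gathered per token
x = tf.random.normal((B, S, D))
indices = tf.random.uniform((B, S, P), maxval=D, dtype=tf.int32)

# GPU-style path: gather along the last axis with two batch dimensions
gathered_gather = tf.gather(x, indices, batch_dims=2)

# TPU-style path: one-hot the indices and contract against x
one_hot_indices = tf.one_hot(indices, depth=D, dtype=x.dtype)  # [B, S, P, D]
gathered_einsum = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)  # [B, S, P]

tf.debugging.assert_near(gathered_gather, gathered_einsum)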