Change how take_along_axis is computed in DeBERTa to stop confusing XLA (#18256)

* Change how `take_along_axis` is computed in DeBERTa to stop confusing XLA

* Greatly simplify take_along_axis() since the code wasn't using most of it
This commit is contained in:
Matt 2022-07-22 12:01:30 -04:00 committed by GitHub
parent d95a32cc60
commit 07505358ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -522,26 +522,13 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
return tf.broadcast_to(pos_index, shapes)
def take_along_axis(x, indices, gather_axis):
if gather_axis < 0:
gather_axis = tf.rank(x) + gather_axis
def take_along_axis(x, indices):
# Only a valid port of np.take_along_axis when the gather axis is -1
if gather_axis != tf.rank(x) - 1:
pre_roll = tf.rank(x) - 1 - gather_axis
permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0)
x = tf.transpose(x, perm=permutation)
indices = tf.transpose(indices, perm=permutation)
else:
pre_roll = 0
flat_x = tf.reshape(x, (-1, tf.shape(x)[-1]))
flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, tf.shape(indices))
if pre_roll != 0:
permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0)
gathered = tf.transpose(gathered, perm=permutation)
gathered = tf.reshape(gathered, indices.shape)
return gathered
@ -775,7 +762,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
tf.squeeze(c2p_pos, 0),
[shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
),
-1,
)
score += c2p_att / scale
@ -803,7 +789,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
tf.squeeze(p2c_pos, 0),
[shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]],
),
-1,
),
[0, 2, 1],
)