Change how take_along_axis is computed in DeBERTa to stop confusing XLA (#18256)

* Change how `take_along_axis` is computed in DeBERTa to stop confusing XLA

* Greatly simplify take_along_axis() since the code wasn't using most of it
This commit is contained in:
Matt 2022-07-22 12:01:30 -04:00 committed by GitHub
parent d95a32cc60
commit 07505358ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -522,26 +522,13 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
     return tf.broadcast_to(pos_index, shapes)


-def take_along_axis(x, indices, gather_axis):
-    if gather_axis < 0:
-        gather_axis = tf.rank(x) + gather_axis
-
-    if gather_axis != tf.rank(x) - 1:
-        pre_roll = tf.rank(x) - 1 - gather_axis
-        permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0)
-        x = tf.transpose(x, perm=permutation)
-        indices = tf.transpose(indices, perm=permutation)
-    else:
-        pre_roll = 0
-
-    flat_x = tf.reshape(x, (-1, tf.shape(x)[-1]))
-    flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
+def take_along_axis(x, indices):
+    # Only a valid port of np.take_along_axis when the gather axis is -1
+    flat_x = tf.reshape(x, (-1, x.shape[-1]))
+    flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
     gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
-    gathered = tf.reshape(gathered, tf.shape(indices))
-
-    if pre_roll != 0:
-        permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0)
-        gathered = tf.transpose(gathered, perm=permutation)
-
+    gathered = tf.reshape(gathered, indices.shape)
     return gathered
@@ -775,7 +762,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
                     tf.squeeze(c2p_pos, 0),
                     [shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
                 ),
-                -1,
             )
             score += c2p_att / scale
@@ -803,7 +789,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
                         tf.squeeze(p2c_pos, 0),
                         [shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]],
                     ),
-                    -1,
                 ),
                 [0, 2, 1],
             )