mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Change how `take_along_axis` is computed in DeBERTa to stop confusing XLA (#18256)
* Change how `take_along_axis` is computed in DeBERTa to stop confusing XLA
* Greatly simplify `take_along_axis()` since the code wasn't using most of it
This commit is contained in:
parent
d95a32cc60
commit
07505358ba
@@ -522,26 +522,13 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
|
||||
return tf.broadcast_to(pos_index, shapes)
|
||||
|
||||
|
||||
def take_along_axis(x, indices, gather_axis):
|
||||
if gather_axis < 0:
|
||||
gather_axis = tf.rank(x) + gather_axis
|
||||
def take_along_axis(x, indices):
|
||||
# Only a valid port of np.take_along_axis when the gather axis is -1
|
||||
|
||||
if gather_axis != tf.rank(x) - 1:
|
||||
pre_roll = tf.rank(x) - 1 - gather_axis
|
||||
permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0)
|
||||
x = tf.transpose(x, perm=permutation)
|
||||
indices = tf.transpose(indices, perm=permutation)
|
||||
else:
|
||||
pre_roll = 0
|
||||
|
||||
flat_x = tf.reshape(x, (-1, tf.shape(x)[-1]))
|
||||
flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
|
||||
flat_x = tf.reshape(x, (-1, x.shape[-1]))
|
||||
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
|
||||
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
|
||||
gathered = tf.reshape(gathered, tf.shape(indices))
|
||||
|
||||
if pre_roll != 0:
|
||||
permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0)
|
||||
gathered = tf.transpose(gathered, perm=permutation)
|
||||
gathered = tf.reshape(gathered, indices.shape)
|
||||
|
||||
return gathered
|
||||
|
||||
@@ -775,7 +762,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
tf.squeeze(c2p_pos, 0),
|
||||
[shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
|
||||
),
|
||||
-1,
|
||||
)
|
||||
score += c2p_att / scale
|
||||
|
||||
@@ -803,7 +789,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
tf.squeeze(p2c_pos, 0),
|
||||
[shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]],
|
||||
),
|
||||
-1,
|
||||
),
|
||||
[0, 2, 1],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user