mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix TFSwinSelfAttention to have relative position index as non-trainable weight (#18226)
Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>
This commit is contained in:
parent
586dcf6b21
commit
575aa6ef1a
@ -461,21 +461,6 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
||||
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = tf.range(self.window_size[0])
|
||||
coords_w = tf.range(self.window_size[1])
|
||||
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
|
||||
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
|
||||
|
||||
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
|
||||
stack_0 += self.window_size[0] - 1
|
||||
stack_0 *= 2 * self.window_size[1] - 1
|
||||
stack_1 += self.window_size[1] - 1
|
||||
relative_coords = tf.stack([stack_0, stack_1], axis=2)
|
||||
self.relative_position_index = tf.reduce_sum(relative_coords, axis=-1)
|
||||
|
||||
self.query = tf.keras.layers.Dense(
|
||||
self.all_head_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
@ -503,6 +488,28 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
||||
initializer="zeros",
|
||||
name="relative_position_bias_table",
|
||||
)
|
||||
self.relative_position_index = self.add_weight(
|
||||
shape=(self.window_size[0] ** 2, self.window_size[1] ** 2),
|
||||
trainable=False,
|
||||
dtype=tf.int32,
|
||||
name="relative_position_index",
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = tf.range(self.window_size[0])
|
||||
coords_w = tf.range(self.window_size[1])
|
||||
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
|
||||
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
|
||||
|
||||
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
|
||||
stack_0 += self.window_size[0] - 1
|
||||
stack_0 *= 2 * self.window_size[1] - 1
|
||||
stack_1 += self.window_size[1] - 1
|
||||
relative_coords = tf.stack([stack_0, stack_1], axis=2)
|
||||
|
||||
self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
|
||||
super().build(input_shape)
|
||||
|
||||
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
|
||||
|
Loading…
Reference in New Issue
Block a user