Fix TFSwinSelfAttention to have relative position index as non-trainable weight (#18226)

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>
This commit is contained in:
Seunghwan Hong 2022-08-05 20:39:40 +09:00 committed by GitHub
parent 586dcf6b21
commit 575aa6ef1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -461,21 +461,6 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
)
# get pair-wise relative position index for each token inside the window
coords_h = tf.range(self.window_size[0])
coords_w = tf.range(self.window_size[1])
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
stack_0 += self.window_size[0] - 1
stack_0 *= 2 * self.window_size[1] - 1
stack_1 += self.window_size[1] - 1
relative_coords = tf.stack([stack_0, stack_1], axis=2)
self.relative_position_index = tf.reduce_sum(relative_coords, axis=-1)
self.query = tf.keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
@ -503,6 +488,28 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
initializer="zeros",
name="relative_position_bias_table",
)
self.relative_position_index = self.add_weight(
shape=(self.window_size[0] ** 2, self.window_size[1] ** 2),
trainable=False,
dtype=tf.int32,
name="relative_position_index",
)
# get pair-wise relative position index for each token inside the window
coords_h = tf.range(self.window_size[0])
coords_w = tf.range(self.window_size[1])
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
stack_0 += self.window_size[0] - 1
stack_0 *= 2 * self.window_size[1] - 1
stack_1 += self.window_size[1] - 1
relative_coords = tf.stack([stack_0, stack_1], axis=2)
self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
super().build(input_shape)
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: