This commit is contained in:
thomwolf 2019-10-09 01:54:44 +02:00
parent d688af19e5
commit 23b7138ab4

View File

@ -226,8 +226,9 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
dim_per_head = self.dim // self.n_heads
assert 2 <= len(tf.shape(mask)) <= 3
causal = (len(tf.shape(mask)) == 3)
mask_shape = shape_list(mask)
assert 2 <= len(mask_shape) <= 3
causal = (mask_shape) == 3)
mask_reshape = [bs, 1, 1, k_length]
def shape(x):