mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
d688af19e5
commit
23b7138ab4
@ -226,8 +226,9 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
|
||||
assert 2 <= len(tf.shape(mask)) <= 3
|
||||
causal = (len(tf.shape(mask)) == 3)
|
||||
mask_shape = shape_list(mask)
|
||||
assert 2 <= len(mask_shape) <= 3
|
||||
causal = (mask_shape) == 3)
|
||||
mask_reshape = [bs, 1, 1, k_length]
|
||||
|
||||
def shape(x):
|
||||
|
Loading…
Reference in New Issue
Block a user