diff --git a/transformers/modeling_tf_distilbert.py b/transformers/modeling_tf_distilbert.py
index 6ed28445670..fa2dc674af3 100644
--- a/transformers/modeling_tf_distilbert.py
+++ b/transformers/modeling_tf_distilbert.py
@@ -226,8 +226,10 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
 
         dim_per_head = self.dim // self.n_heads
 
-        assert 2 <= len(tf.shape(mask)) <= 3
-        causal = (len(tf.shape(mask)) == 3)
+        # The mask's rank is the length of its shape list; a rank-3 mask is treated as causal.
+        mask_shape = shape_list(mask)
+        assert 2 <= len(mask_shape) <= 3
+        causal = (len(mask_shape) == 3)
         mask_reshape = [bs, 1, 1, k_length]
 
         def shape(x):