Mirror of https://github.com/huggingface/transformers.git
Replace tf.math.divide with int(/) to remove dim_per_head from the TF graph (#14600)
Co-authored-by: yis <yis@graphcore.ai>
commit 96cc02b51b
parent 43f953cc2e
@@ -170,7 +170,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         k_length = shape_list(key)[1]
         # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
         # assert key.size() == value.size()
-        dim_per_head = tf.math.divide(self.dim, self.n_heads)
+        dim_per_head = int(self.dim / self.n_heads)
         dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
         mask_reshape = [bs, 1, 1, k_length]
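Why this matters, as a minimal standalone sketch (TF 2.x assumed; the values 768 and 12 are hypothetical stand-ins for the model config's dim and n_heads): tf.math.divide always returns a Tensor, so the old dim_per_head was a float64 node carried in the TF graph, while plain Python division on the static config attributes runs once at trace time and yields an ordinary int, so no divide op enters the graph at all.

    import tensorflow as tf

    dim, n_heads = 768, 12  # hypothetical config values

    # Before: tf.math.divide returns a float64 Tensor, i.e. a node in the
    # TF graph, which is why the original code follows it with a tf.cast.
    dim_per_head_old = tf.math.divide(dim, n_heads)
    print(type(dim_per_head_old))  # an EagerTensor/graph tensor, value 64.0

    # After: Python-level division yields a plain int at trace time, so the
    # division is folded into the graph as a static constant. (The tf.cast
    # line kept by the diff now just converts this Python constant.)
    dim_per_head_new = int(dim / n_heads)
    print(type(dim_per_head_new))  # <class 'int'>, value 64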