From 96cc02b51b60b71bce2ca38e23f8e1f920b6c626 Mon Sep 17 00:00:00 2001 From: yis11178 <73350188+yis11178@users.noreply.github.com> Date: Thu, 2 Dec 2021 13:13:42 +0000 Subject: [PATCH] replace tf.math.divide with int(/) to remove dim_per_head from the TF graph (#14600) Co-authored-by: yis --- src/transformers/models/distilbert/modeling_tf_distilbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/distilbert/modeling_tf_distilbert.py b/src/transformers/models/distilbert/modeling_tf_distilbert.py index e997d2a72cc..172194d1925 100644 --- a/src/transformers/models/distilbert/modeling_tf_distilbert.py +++ b/src/transformers/models/distilbert/modeling_tf_distilbert.py @@ -170,7 +170,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): k_length = shape_list(key)[1] # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' # assert key.size() == value.size() - dim_per_head = tf.math.divide(self.dim, self.n_heads) + dim_per_head = int(self.dim / self.n_heads) dim_per_head = tf.cast(dim_per_head, dtype=tf.int32) mask_reshape = [bs, 1, 1, k_length]