Co-authored-by: susnato <susnato@tensorflow123456@gmail.com>
This commit is contained in:
Susnato Dhar 2023-01-16 19:36:35 +05:30 committed by GitHub
parent 488a179ce1
commit a5327c6a9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -446,7 +446,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.vocab_size, dtype=input_ids.dtype),
tf.cast(self.config.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"