Fix tf random token masking probability in data collator (#21834)

* fix tf random mask tokens probability

* fix tf random mask tokens probability in collator for langauge modelling
This commit is contained in:
anruijian 2023-02-28 07:55:47 -05:00 committed by GitHub
parent 4fe744f528
commit 2d506ea4c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -679,7 +679,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
inputs = tf.where(indices_replaced, mask_token_id, inputs) inputs = tf.where(indices_replaced, mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word # 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64) random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs) inputs = tf.where(indices_random, random_words, inputs)
@ -1062,7 +1062,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs) inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word # 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs) inputs = tf.where(indices_random, random_words, inputs)