diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 5149ff39c84..90491d6aa84 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -679,7 +679,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): inputs = tf.where(indices_replaced, mask_token_id, inputs) # 10% of the time, we replace masked input tokens with random word - indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced + indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64) inputs = tf.where(indices_random, random_words, inputs) @@ -1062,7 +1062,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs) # 10% of the time, we replace masked input tokens with random word - indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced + indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) inputs = tf.where(indices_random, random_words, inputs)