From a1cf9f33908a60ba74366c575e0e6942b588be69 Mon Sep 17 00:00:00 2001 From: gautham <91133513+capemox@users.noreply.github.com> Date: Fri, 7 Mar 2025 19:39:27 +0530 Subject: [PATCH] Fixed datatype related issues in `DataCollatorForLanguageModeling` (#36457) Fixed 2 issues regarding `tests/trainer/test_data_collator.py::TFDataCollatorIntegrationTest::test_all_mask_replacement`: 1. I got the error `RuntimeError: "bernoulli_tensor_cpu_p_" not implemented for 'Long'`. This is because `mask_replace_prob=1` is an integer, and `torch.bernoulli` doesn't accept the resulting `torch.long` dtype (it expects a floating-point dtype instead). I fixed this by manually casting the probability arguments to `float` in the `__post_init__` method of `DataCollatorForLanguageModeling`. 2. I also got the error `tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute Equal as input #1(zero-based) was expected to be a int64 tensor but is a int32 tensor [Op:Equal]` due to the line `tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))` in `test_data_collator.py`. This occurs because the dtype of the `inputs` variable is `tf.int32`. I solved this by manually casting it to `tf.int64` in the test, as the expected dtype of `batch["input_ids"]` is `tf.int64`. 
--- src/transformers/data/data_collator.py | 4 ++++ tests/trainer/test_data_collator.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 4af7d609f03..1de84685712 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -843,6 +843,10 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): if self.random_replace_prob < 0 or self.random_replace_prob > 1: raise ValueError("random_replace_prob should be between 0 and 1.") + self.mlm_probability = float(self.mlm_probability) + self.mask_replace_prob = float(self.mask_replace_prob) + self.random_replace_prob = float(self.random_replace_prob) + if self.tf_experimental_compile: import tensorflow as tf diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index c3e9b5a3bad..d631299c01f 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -1052,7 +1052,9 @@ class TFDataCollatorIntegrationTest(unittest.TestCase): # confirm that every token is either the original token or [MASK] self.assertTrue( - tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)) + tf.reduce_all( + (batch["input_ids"] == tf.cast(inputs, tf.int64)) | (batch["input_ids"] == tokenizer.mask_token_id) + ) ) # numpy call