Allow mlm_probability to be set to None when mlm=False in DataCollatorForLanguageModeling (#38522) (#38537)

* mlm_probability in DataCollatorForLanguageModeling should be validated only when mlm is True (#38522)

* Change mlm_probability to Optional in DataCollatorForLanguageModeling (#38537)

---------

Co-authored-by: eak <eak@ivalua.com>
This commit is contained in:
KameniAlexNea 2025-06-05 14:54:12 +02:00 committed by GitHub
parent 65f5fa71cd
commit 8f630651b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -824,7 +824,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
mlm_probability: Optional[float] = 0.15
mask_replace_prob: float = 0.8
random_replace_prob: float = 0.1
pad_to_multiple_of: Optional[int] = None
@ -833,13 +833,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
seed: Optional[int] = None
def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
if self.mlm_probability < 0 or self.mlm_probability > 1:
raise ValueError("mlm_probability should be between 0 and 1.")
if self.mlm:
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
if self.mlm_probability is None or self.mlm_probability < 0 or self.mlm_probability > 1:
raise ValueError("mlm_probability should be between 0 and 1.")
self.mlm_probability = float(self.mlm_probability)
if self.mask_replace_prob + self.random_replace_prob > 1:
raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
@ -847,7 +849,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
raise ValueError("random_replace_prob should be between 0 and 1.")
self.mlm_probability = float(self.mlm_probability)
self.mask_replace_prob = float(self.mask_replace_prob)
self.random_replace_prob = float(self.random_replace_prob)