diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 55aed55a13f..079fb13c2a8 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -824,7 +824,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): tokenizer: PreTrainedTokenizerBase mlm: bool = True - mlm_probability: float = 0.15 + mlm_probability: Optional[float] = 0.15 mask_replace_prob: float = 0.8 random_replace_prob: float = 0.1 pad_to_multiple_of: Optional[int] = None @@ -833,13 +833,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): seed: Optional[int] = None def __post_init__(self): - if self.mlm and self.tokenizer.mask_token is None: - raise ValueError( - "This tokenizer does not have a mask token which is necessary for masked language modeling. " - "You should pass `mlm=False` to train on causal language modeling instead." - ) - if self.mlm_probability < 0 or self.mlm_probability > 1: - raise ValueError("mlm_probability should be between 0 and 1.") + if self.mlm: + if self.tokenizer.mask_token is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language modeling. " + "You should pass `mlm=False` to train on causal language modeling instead." + ) + if self.mlm_probability is None or self.mlm_probability < 0 or self.mlm_probability > 1: + raise ValueError("mlm_probability should be between 0 and 1.") + self.mlm_probability = float(self.mlm_probability) if self.mask_replace_prob + self.random_replace_prob > 1: raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1") if self.mask_replace_prob < 0 or self.mask_replace_prob > 1: @@ -847,7 +849,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): if self.random_replace_prob < 0 or self.random_replace_prob > 1: raise ValueError("random_replace_prob should be between 0 and 1.") - self.mlm_probability = float(self.mlm_probability) self.mask_replace_prob = float(self.mask_replace_prob) self.random_replace_prob = float(self.random_replace_prob)