mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Allow mlm_probability
to be set to None
when mlm=False
in DataCollatorForLanguageModeling (#38522) (#38537)
* mlm_probability in DataCollatorForLanguageModeling should be validated only when mlm is True (#38522) * Change mlm_probability to Optional in DataCollatorForLanguageModeling (#38537) --------- Co-authored-by: eak <eak@ivalua.com>
This commit is contained in:
parent
65f5fa71cd
commit
8f630651b0
@ -824,7 +824,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
mlm: bool = True
|
mlm: bool = True
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: Optional[float] = 0.15
|
||||||
mask_replace_prob: float = 0.8
|
mask_replace_prob: float = 0.8
|
||||||
random_replace_prob: float = 0.1
|
random_replace_prob: float = 0.1
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
@ -833,13 +833,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.mlm and self.tokenizer.mask_token is None:
|
if self.mlm:
|
||||||
raise ValueError(
|
if self.tokenizer.mask_token is None:
|
||||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
raise ValueError(
|
||||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||||
)
|
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||||
if self.mlm_probability < 0 or self.mlm_probability > 1:
|
)
|
||||||
raise ValueError("mlm_probability should be between 0 and 1.")
|
if self.mlm_probability is None or self.mlm_probability < 0 or self.mlm_probability > 1:
|
||||||
|
raise ValueError("mlm_probability should be between 0 and 1.")
|
||||||
|
self.mlm_probability = float(self.mlm_probability)
|
||||||
if self.mask_replace_prob + self.random_replace_prob > 1:
|
if self.mask_replace_prob + self.random_replace_prob > 1:
|
||||||
raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
|
raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
|
||||||
if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
|
if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
|
||||||
@ -847,7 +849,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
|
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
|
||||||
raise ValueError("random_replace_prob should be between 0 and 1.")
|
raise ValueError("random_replace_prob should be between 0 and 1.")
|
||||||
|
|
||||||
self.mlm_probability = float(self.mlm_probability)
|
|
||||||
self.mask_replace_prob = float(self.mask_replace_prob)
|
self.mask_replace_prob = float(self.mask_replace_prob)
|
||||||
self.random_replace_prob = float(self.random_replace_prob)
|
self.random_replace_prob = float(self.random_replace_prob)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user