mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Enhance DataCollatorForLanguageModeling with Configurable Token Replacement Probabilities (#35251)
* DataCollatorForLanguageModeling class was updated with new parameters that provides more control over the token masking and relacing * DataCollatorForLanguageModeling class was updated with new parameters that provides more control over the token masking and relacing * Addressed review comments, modified the docstring and made a test for the DataCollatorForLanguageModeling
This commit is contained in:
parent
b0cdbd9119
commit
c61fcde910
@ -691,11 +691,17 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
tokens and the value to predict for the masked token.
|
tokens and the value to predict for the masked token.
|
||||||
mlm_probability (`float`, *optional*, defaults to 0.15):
|
mlm_probability (`float`, *optional*, defaults to 0.15):
|
||||||
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
|
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
|
||||||
|
mask_replace_prob (`float`, *optional*, defaults to 0.8):
|
||||||
|
The probability with which masked tokens are replaced by the tokenizer's mask token (e.g., `[MASK]`).
|
||||||
|
Defaults to 0.8, meaning 80% of the masked tokens will be replaced with `[MASK]`.
|
||||||
|
Only works when `mlm` is set to `True`.
|
||||||
|
random_replace_prob (`float`, *optional*, defaults to 0.1):
|
||||||
|
The probability with which masked tokens are replaced by random tokens from the tokenizer's vocabulary.
|
||||||
|
Defaults to 0.1, meaning 10% of the masked tokens will be replaced with random tokens. The remaining
|
||||||
|
masked tokens (1 - mask_replace_prob - random_replace_prob) are left unchanged.
|
||||||
|
Only works when `mlm` is set to `True`.
|
||||||
pad_to_multiple_of (`int`, *optional*):
|
pad_to_multiple_of (`int`, *optional*):
|
||||||
If set will pad the sequence to a multiple of the provided value.
|
If set, will pad the sequence to a multiple of the provided value.
|
||||||
|
|
||||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
|
||||||
7.0 (Volta).
|
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
|
|
||||||
@ -705,11 +711,36 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
|
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
|
||||||
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
|
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
|
||||||
|
|
||||||
</Tip>"""
|
<Example Options and Expectations>
|
||||||
|
|
||||||
|
1. Default Behavior:
|
||||||
|
- `mask_replace_prob=0.8`, `random_replace_prob=0.1`.
|
||||||
|
- Expect 80% of masked tokens replaced with `[MASK]`, 10% replaced with random tokens, and 10% left unchanged.
|
||||||
|
|
||||||
|
2. All masked tokens replaced by `[MASK]`:
|
||||||
|
- `mask_replace_prob=1.0`, `random_replace_prob=0.0`.
|
||||||
|
- Expect all masked tokens to be replaced with `[MASK]`. No tokens are left unchanged or replaced with random tokens.
|
||||||
|
|
||||||
|
3. No `[MASK]` replacement, only random tokens:
|
||||||
|
- `mask_replace_prob=0.0`, `random_replace_prob=1.0`.
|
||||||
|
- Expect all masked tokens to be replaced with random tokens. No `[MASK]` replacements or unchanged tokens.
|
||||||
|
|
||||||
|
4. Balanced replacement:
|
||||||
|
- `mask_replace_prob=0.5`, `random_replace_prob=0.4`.
|
||||||
|
- Expect 50% of masked tokens replaced with `[MASK]`, 40% replaced with random tokens, and 10% left unchanged.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The sum of `mask_replace_prob` and `random_replace_prob` must not exceed 1. If their sum is less than 1, the
|
||||||
|
remaining proportion will consist of masked tokens left unchanged.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
mlm: bool = True
|
mlm: bool = True
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
mask_replace_prob: float = 0.8
|
||||||
|
random_replace_prob: float = 0.1
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
tf_experimental_compile: bool = False
|
tf_experimental_compile: bool = False
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
@ -720,6 +751,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||||
)
|
)
|
||||||
|
if self.mlm_probability < 0 or self.mlm_probability > 1:
|
||||||
|
raise ValueError("mlm_probability should be between 0 and 1.")
|
||||||
|
if self.mask_replace_prob + self.random_replace_prob > 1:
|
||||||
|
raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
|
||||||
|
if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
|
||||||
|
raise ValueError("mask_replace_prob should be between 0 and 1.")
|
||||||
|
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
|
||||||
|
raise ValueError("random_replace_prob should be between 0 and 1.")
|
||||||
|
|
||||||
if self.tf_experimental_compile:
|
if self.tf_experimental_compile:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@ -749,18 +789,28 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
||||||
labels = tf.where(masked_indices, inputs, -100)
|
labels = tf.where(masked_indices, inputs, -100)
|
||||||
|
|
||||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||||
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
|
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
|
||||||
|
|
||||||
inputs = tf.where(indices_replaced, mask_token_id, inputs)
|
inputs = tf.where(indices_replaced, mask_token_id, inputs)
|
||||||
|
|
||||||
# 10% of the time, we replace masked input tokens with random word
|
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
|
||||||
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
|
return inputs, labels
|
||||||
|
|
||||||
|
remaining_prob = 1 - self.mask_replace_prob
|
||||||
|
# scaling the random_replace_prob to the remaining probability for example if
|
||||||
|
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
|
||||||
|
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
|
||||||
|
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
|
||||||
|
# random_replace_prob% of the time, we replace masked input tokens with random word
|
||||||
|
indices_random = (
|
||||||
|
self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
|
||||||
|
)
|
||||||
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
||||||
|
|
||||||
inputs = tf.where(indices_random, random_words, inputs)
|
inputs = tf.where(indices_random, random_words, inputs)
|
||||||
|
|
||||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
|
||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
@ -849,16 +899,29 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
masked_indices = torch.bernoulli(probability_matrix).bool()
|
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||||
|
|
||||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||||
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
|
||||||
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||||
|
|
||||||
# 10% of the time, we replace masked input tokens with random word
|
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
|
||||||
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
return inputs, labels
|
||||||
|
|
||||||
|
remaining_prob = 1 - self.mask_replace_prob
|
||||||
|
# scaling the random_replace_prob to the remaining probability for example if
|
||||||
|
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
|
||||||
|
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
|
||||||
|
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
|
||||||
|
|
||||||
|
# random_replace_prob% of the time, we replace masked input tokens with random word
|
||||||
|
indices_random = (
|
||||||
|
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
|
||||||
|
& masked_indices
|
||||||
|
& ~indices_replaced
|
||||||
|
)
|
||||||
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
||||||
inputs[indices_random] = random_words[indices_random]
|
inputs[indices_random] = random_words[indices_random]
|
||||||
|
|
||||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
|
||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
@ -905,14 +968,24 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
||||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||||
|
|
||||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||||
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
|
indices_replaced = (
|
||||||
|
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
|
||||||
|
)
|
||||||
inputs[indices_replaced] = self.tokenizer.mask_token_id
|
inputs[indices_replaced] = self.tokenizer.mask_token_id
|
||||||
|
|
||||||
# 10% of the time, we replace masked input tokens with random word
|
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
|
||||||
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
return inputs, labels
|
||||||
|
|
||||||
|
remaining_prob = 1 - self.mask_replace_prob
|
||||||
|
# scaling the random_replace_prob to the remaining probability for example if
|
||||||
|
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
|
||||||
|
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
|
||||||
|
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
|
||||||
indices_random = (
|
indices_random = (
|
||||||
np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
|
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
|
||||||
|
& masked_indices
|
||||||
|
& ~indices_replaced
|
||||||
)
|
)
|
||||||
random_words = np.random.randint(
|
random_words = np.random.randint(
|
||||||
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
||||||
|
@ -1020,6 +1020,52 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(tf.reduce_any(masked_tokens))
|
self.assertTrue(tf.reduce_any(masked_tokens))
|
||||||
# self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
|
# self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
|
||||||
|
|
||||||
|
def test_probability_sum_error(self):
|
||||||
|
"""Test that the sum of mask_replace_prob and random_replace_prob exceeding 1 raises an error."""
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
DataCollatorForLanguageModeling(tokenizer=tokenizer, mask_replace_prob=0.9, random_replace_prob=0.2)
|
||||||
|
|
||||||
|
def test_all_mask_replacement(self):
|
||||||
|
"""Test behavior when mask_replace_prob=1."""
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
|
|
||||||
|
# pytorch call
|
||||||
|
collator = DataCollatorForLanguageModeling(
|
||||||
|
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = torch.tensor([0, 1, 2, 3, 4, 5])
|
||||||
|
features = [{"input_ids": inputs} for _ in range(8)]
|
||||||
|
batch = collator(features)
|
||||||
|
|
||||||
|
# confirm that every token is either the original token or [MASK]
|
||||||
|
self.assertTrue(torch.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
|
||||||
|
|
||||||
|
# tf call
|
||||||
|
collator = DataCollatorForLanguageModeling(
|
||||||
|
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="tf"
|
||||||
|
)
|
||||||
|
inputs = tf.constant([0, 1, 2, 3, 4, 5])
|
||||||
|
features = [{"input_ids": inputs} for _ in range(8)]
|
||||||
|
batch = collator(features)
|
||||||
|
|
||||||
|
# confirm that every token is either the original token or [MASK]
|
||||||
|
self.assertTrue(
|
||||||
|
tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))
|
||||||
|
)
|
||||||
|
|
||||||
|
# numpy call
|
||||||
|
collator = DataCollatorForLanguageModeling(
|
||||||
|
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="np"
|
||||||
|
)
|
||||||
|
inputs = np.array([0, 1, 2, 3, 4, 5])
|
||||||
|
features = [{"input_ids": inputs} for _ in range(8)]
|
||||||
|
batch = collator(features)
|
||||||
|
|
||||||
|
# confirm that every token is either the original token or [MASK]
|
||||||
|
self.assertTrue(np.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
|
||||||
|
|
||||||
def test_data_collator_for_language_modeling(self):
|
def test_data_collator_for_language_modeling(self):
|
||||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
||||||
|
Loading…
Reference in New Issue
Block a user