Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Added support for seed in DataCollatorForWholeWordMask (#36903)

* Added support for seed in `DataCollatorForWholeWordMask`, and also wrote tests. Also fixed bugs where the code hardcoded values for mask replacement probability and random replacement probability, instead of using the values passed by the user.
* formatting issues
* Used better way to generate seed in TF. Made tests more consistent.

Parent: 5932606d8e
Commit: 48385aa4f4
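
For orientation, a minimal usage sketch of the new behaviour (hedged: the checkpoint name is illustrative, and the probability arguments are assumed to be constructor parameters, as the attribute names in the diff suggest):

```python
from transformers import AutoTokenizer, DataCollatorForWholeWordMask

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# With a fixed seed, two identically configured collators produce identical masking.
# mask_replace_prob / random_replace_prob are now honoured instead of the previously
# hardcoded 80% [MASK] / 10% random-word split.
collator = DataCollatorForWholeWordMask(
    tokenizer,
    mlm_probability=0.15,
    mask_replace_prob=0.8,
    random_replace_prob=0.1,
    seed=42,
)
```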
@@ -1193,6 +1193,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
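
The hunk above calls a `create_rng` helper that is not shown in this diff. A hypothetical sketch of what such a helper could look like, one generator type per backend (an assumption for illustration, not the actual implementation):

```python
def create_rng(self):
    # Hypothetical reconstruction: build a framework-appropriate generator from self.seed.
    if self.return_tensors == "pt":
        import torch

        self.generator = torch.Generator().manual_seed(self.seed)
    elif self.return_tensors == "tf":
        import tensorflow as tf

        self.generator = tf.random.Generator.from_seed(self.seed)
    elif self.return_tensors == "np":
        import numpy as np

        self.generator = np.random.default_rng(self.seed)
```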
@@ -1223,6 +1228,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         import tensorflow as tf

+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1251,6 +1261,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1278,6 +1293,30 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}

+    def _shuffle(self, cand_indexes):
+        # if no seed, just use random's shuffle
+        if self.seed is None:
+            random.shuffle(cand_indexes)
+            return cand_indexes
+
+        # if seed is provided, use the generator to shuffle
+        if self.return_tensors == "pt":
+            import torch
+
+            indices = torch.randperm(len(cand_indexes), generator=self.generator)
+            return [cand_indexes[i] for i in indices]
+
+        elif self.return_tensors == "tf":
+            import tensorflow as tf
+
+            seed = self.generator.make_seeds(2)[0]
+            indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist()
+            return [cand_indexes[i] for i in indices]
+
+        elif self.return_tensors == "np":
+            self.generator.shuffle(cand_indexes)
+            return cand_indexes
+
     def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
         """
         Get 0/1 labels for masked tokens with whole word mask proxy
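
To illustrate why `_shuffle` threads the generator through instead of always calling `random.shuffle`: a seeded `torch.Generator` makes the permutation reproducible from run to run (a minimal sketch for the PyTorch backend only):

```python
import torch

def shuffled(items, seed):
    # A fresh generator with the same seed always yields the same permutation.
    g = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(items), generator=g)
    return [items[i] for i in order]

assert shuffled(list(range(10)), seed=42) == shuffled(list(range(10)), seed=42)
```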
@@ -1298,7 +1337,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
             else:
                 cand_indexes.append([i])

-        random.shuffle(cand_indexes)
+        cand_indexes = self._shuffle(cand_indexes)
         num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
         masked_lms = []
         covered_indexes = set()
@@ -1346,16 +1385,32 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         masked_indices = probability_matrix.bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = (
+            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
+            & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scaling the random_replace_prob to the remaining probability for example if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replacement_prob% of the time, we replace masked input tokens with random word
+        indices_random = (
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
+            & masked_indices
+            & ~indices_replaced
+        )
+        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1-random_replacement_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
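
The rescaling of `random_replace_prob` is worth a quick numeric check: the random-word draw only applies to masked positions that were not already replaced by `[MASK]`, so it must be conditioned on the remaining `1 - mask_replace_prob` mass to recover the user-specified unconditional probability (plain Python, illustrative values):

```python
mask_replace_prob = 0.8
random_replace_prob = 0.1

remaining_prob = 1 - mask_replace_prob                               # 0.2 of masked positions skip [MASK]
random_replace_prob_scaled = random_replace_prob / remaining_prob    # 0.1 / 0.2 = 0.5

# Unconditional probability of landing in the random-word branch:
p_random = remaining_prob * random_replace_prob_scaled
assert abs(p_random - random_replace_prob) < 1e-12                   # back to the requested 10%

# Everything left over keeps the original token:
p_keep = 1 - mask_replace_prob - random_replace_prob                 # 0.1
```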
@@ -1387,17 +1442,35 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices
+
         inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
-        random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scaling the random_replace_prob to the remaining probability for example if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replace_prob% of the time, we replace masked input tokens with random word
+        indices_random = (
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
+            & masked_indices
+            & ~indices_replaced
+        )
+
+        if self.generator:
+            random_words = self.generator.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+        else:
+            random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)

         inputs = tf.where(indices_random, random_words, inputs)

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
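
`tf_bernoulli` is now called with the generator as an extra argument, but the helper itself is outside this diff. A hedged sketch of what a generator-aware version might look like (the signature and body here are assumptions for illustration, not the actual helper):

```python
import tensorflow as tf

def tf_bernoulli(shape, probability, generator=None):
    # Draw a boolean mask where each position is True with the given probability.
    # A tf.random.Generator makes the draw reproducible; otherwise fall back to the global RNG.
    prob_matrix = tf.fill(shape, probability)
    if generator is not None:
        uniform = generator.uniform(shape, minval=0, maxval=1)
    else:
        uniform = tf.random.uniform(shape, minval=0, maxval=1)
    return tf.cast(prob_matrix - uniform >= 0, tf.bool)
```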
@@ -1425,19 +1498,44 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):

         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
+        # mask_replacement_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        if self.generator:
+            indices_replaced = (
+                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
+        else:
+            indices_replaced = (
+                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        indices_random = (
-            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
-        )
-        random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scaling the random_replace_prob to the remaining probability for example if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        if self.generator:
+            indices_random = (
+                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = self.generator.integers(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+        else:
+            indices_random = (
+                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+
         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

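
In the NumPy path, the seeded branch swaps the legacy module-level `np.random` functions for the equivalent methods on `np.random.Generator`, which is what `self.generator` is assumed to hold for the `"np"` backend. A minimal side-by-side sketch; the vocabulary size is illustrative:

```python
import numpy as np

shape = (2, 8)
rng = np.random.default_rng(42)

# Legacy global-state API             ->  seeded Generator API
legacy_mask = np.random.binomial(1, 0.8, size=shape).astype(bool)
seeded_mask = rng.binomial(1, 0.8, size=shape).astype(bool)

legacy_words = np.random.randint(low=0, high=30522, size=shape, dtype=np.int64)
seeded_words = rng.integers(low=0, high=30522, size=shape, dtype=np.int64)
```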
@@ -445,6 +445,86 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

+    def test_data_collator_for_whole_word_mask_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if seed is respected between two different DataCollatorForWholeWordMask instances
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))
+
+        self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
+
+        # check if seed is respected in multiple workers situation
+        features = [{"input_ids": list(range(1000))} for _ in range(10)]
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            generator=torch.Generator().manual_seed(42),
+            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
+        )
+
+        batch_3_input_ids = []
+        batch_3_labels = []
+        for batch in dataloader:
+            batch_3_input_ids.append(batch["input_ids"])
+            batch_3_labels.append(batch["labels"])
+
+        batch_3_input_ids = torch.stack(batch_3_input_ids)
+        batch_3_labels = torch.stack(batch_3_labels)
+        self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))
+
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
+        )
+
+        batch_4_input_ids = []
+        batch_4_labels = []
+        for batch in dataloader:
+            batch_4_input_ids.append(batch["input_ids"])
+            batch_4_labels.append(batch["labels"])
+        batch_4_input_ids = torch.stack(batch_4_input_ids)
+        batch_4_labels = torch.stack(batch_4_labels)
+        self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))
+
+        self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
+        self.assertTrue(torch.all(batch_3_labels == batch_4_labels))
+
+        # try with different seed
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=43),
+        )
+
+        batch_5_input_ids = []
+        batch_5_labels = []
+        for batch in dataloader:
+            batch_5_input_ids.append(batch["input_ids"])
+            batch_5_labels.append(batch["labels"])
+        batch_5_input_ids = torch.stack(batch_5_input_ids)
+        batch_5_labels = torch.stack(batch_5_labels)
+        self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))
+
+        self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
+        self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -1199,6 +1279,33 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
         self.assertEqual(batch["labels"].shape.as_list(), [2, 10])

+    def test_data_collator_for_whole_word_mask_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if seed is respected between two different DataCollatorForWholeWordMask instances
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000])
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000])
+
+        self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
+
+        # try with different seed
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf")
+        batch_3 = data_collator(features)
+        self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000])
+
+        self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
+        self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -1920,6 +2027,32 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape, (2, 10))
         self.assertEqual(batch["labels"].shape, (2, 10))

+    def test_data_collator_for_whole_word_mask_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if seed is respected between two different DataCollatorForWholeWordMask instances
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape, (2, 1000))
+        self.assertEqual(batch_1["labels"].shape, (2, 1000))
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape, (2, 1000))
+        self.assertEqual(batch_2["labels"].shape, (2, 1000))
+
+        self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="np")
+        batch_3 = data_collator(features)
+        self.assertEqual(batch_3["input_ids"].shape, (2, 1000))
+        self.assertEqual(batch_3["labels"].shape, (2, 1000))
+
+        self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
+        self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]