minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf" (#13891)

* minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf"

* more consinstent implementation for numpy_mask_tokens
This commit is contained in:
Dean Wyatte 2021-11-03 08:36:41 -06:00 committed by GitHub
parent 671569ddf7
commit 27b1516d32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 8 deletions

View File

@ -883,14 +883,14 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}
def np_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]
batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
mask_labels = []
for e in examples:
@ -1009,15 +1009,15 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
labels = tf.identity(inputs)
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
masked_indices = tf.cast(mask_labels, tf.bool)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
]
masked_indices = masked_indices & ~tf.convert_to_tensor(special_tokens_mask, dtype=tf.bool)
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
if self.tokenizer._pad_token is not None:
padding_mask = inputs == self.tokenizer.pad_token_id
masked_indices = masked_indices & ~padding_mask
@ -1073,9 +1073,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
)
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged

View File

@ -24,6 +24,7 @@ from transformers import (
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
@ -224,6 +225,16 @@ class DataCollatorIntegrationTest(unittest.TestCase):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)
def test_data_collator_for_whole_word_mask(self):
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@ -488,6 +499,16 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)
def test_data_collator_for_whole_word_mask(self):
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@ -750,6 +771,16 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)
def test_data_collator_for_whole_word_mask(self):
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (2, 10))
self.assertEqual(batch["labels"].shape, (2, 10))
def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]