Mirror of https://github.com/huggingface/transformers.git
TF/Numpy variants for all DataCollator classes (#13105)
* Adding a TF variant of the DataCollatorForTokenClassification to get feedback
* Added a Numpy variant and a post_init check to fail early if a missing import is found
* Fixed call to Numpy variant
* Added a couple more of the collators
* Update src/transformers/data/data_collator.py
* Fixes, style pass, finished DataCollatorForSeqToSeq
* Added all the LanguageModeling DataCollators, except SOP and PermutationLanguageModeling
* Adding DataCollatorForPermutationLanguageModeling
* Style pass
* Add missing `__call__` for PLM
* Remove `post_init` checks for frameworks because the imports inside them were making us fail code quality checks
* Remove unused imports
* First attempt at some TF tests
* A second attempt to make any of those tests actually work
* TF tests, round three
* TF tests, round four
* TF tests, round five
* TF tests, all enabled!
* Style pass
* Merging tests into `test_data_collator.py`
* Merging tests into `test_data_collator.py`
* Fixing up test imports
* Fixing up test imports
* Trying shuffling the conditionals around
* Commenting out non-functional old tests
* Completed all tests for all three frameworks
* Style pass
* Fixed test typo
* Style pass
* Move standard `__call__` method to mixin
* Rearranged imports for `test_data_collator`
* Fix data collator typo "torch" -> "pt"
* Fixed the most embarrassingly obvious bug
* Update src/transformers/data/data_collator.py
* Renaming mixin
* Updating docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Dalton Walker <dalton_walker@icloud.com>
Co-authored-by: Andrew Romans <andrew.romans@hotmail.com>
This commit is contained in:
parent: 74b3344fbc
commit: 854260ca44
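The thrust of the change: every collator class, as well as `default_data_collator`, now accepts a `return_tensors` argument selecting PyTorch (`"pt"`, the default), TensorFlow (`"tf"`), or NumPy (`"np"`) output. A minimal usage sketch, using only calls that the new tests in this commit exercise:

```python
from transformers import default_data_collator

features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]

# NumPy output needs no deep-learning framework installed.
np_batch = default_data_collator(features, return_tensors="np")
assert np_batch["inputs"].shape == (8, 6)

# "tf" requires TensorFlow; "pt" (the default) requires PyTorch.
tf_batch = default_data_collator(features, return_tensors="tf")
assert tf_batch["inputs"].shape.as_list() == [8, 6]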
@@ -54,18 +54,18 @@ DataCollatorForLanguageModeling
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.data.data_collator.DataCollatorForLanguageModeling
-    :members: mask_tokens
+    :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
 
 
 DataCollatorForWholeWordMask
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.data.data_collator.DataCollatorForWholeWordMask
-    :members: mask_tokens
+    :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
 
 
 DataCollatorForPermutationLanguageModeling
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.data.data_collator.DataCollatorForPermutationLanguageModeling
-    :members: mask_tokens
+    :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
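For `DataCollatorForLanguageModeling` and `DataCollatorForWholeWordMask`, the three `*_mask_tokens` members documented above implement the same BERT-style masking scheme per framework: of the positions selected for prediction (15% by default), 80% become the mask token, 10% a random token, and 10% stay unchanged, with labels set to -100 everywhere else. A NumPy sketch of that scheme; the function name and signature are illustrative, not the library's API, and special-token handling is omitted:

```python
import numpy as np

def numpy_mlm_mask_sketch(input_ids, mask_token_id, vocab_size, mlm_probability=0.15, rng=None):
    """Illustrative BERT-style MLM masking: of the tokens chosen for prediction,
    80% -> [MASK], 10% -> random token, 10% -> unchanged; labels are -100
    at all positions that are not predicted."""
    rng = rng or np.random.default_rng()
    inputs = np.array(input_ids)
    labels = inputs.copy()

    masked = rng.random(inputs.shape) < mlm_probability
    labels[~masked] = -100  # loss is only computed on the selected tokens

    replaced = (rng.random(inputs.shape) < 0.8) & masked
    inputs[replaced] = mask_token_id

    randomized = (rng.random(inputs.shape) < 0.5) & masked & ~replaced
    inputs[randomized] = rng.integers(0, vocab_size, inputs.shape)[randomized]
    return inputs, labels

# e.g.: inputs, labels = numpy_mlm_mask_sketch(np.arange(20), mask_token_id=4, vocab_size=100)
```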
@@ -81,6 +81,17 @@ _import_structure = {
         "xnli_processors",
         "xnli_tasks_num_labels",
     ],
+    "data.data_collator": [
+        "DataCollator",
+        "DataCollatorForLanguageModeling",
+        "DataCollatorForPermutationLanguageModeling",
+        "DataCollatorForSeq2Seq",
+        "DataCollatorForSOP",
+        "DataCollatorForTokenClassification",
+        "DataCollatorForWholeWordMask",
+        "DataCollatorWithPadding",
+        "default_data_collator",
+    ],
     "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
     "file_utils": [
         "CONFIG_NAME",
@@ -460,17 +471,6 @@ else:
 if is_torch_available():
     _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
     _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
-    _import_structure["data.data_collator"] = [
-        "DataCollator",
-        "DataCollatorForLanguageModeling",
-        "DataCollatorForPermutationLanguageModeling",
-        "DataCollatorForSeq2Seq",
-        "DataCollatorForSOP",
-        "DataCollatorForTokenClassification",
-        "DataCollatorForWholeWordMask",
-        "DataCollatorWithPadding",
-        "default_data_collator",
-    ]
     _import_structure["data.datasets"] = [
         "GlueDataset",
         "GlueDataTrainingArguments",
@@ -1830,6 +1830,17 @@ if TYPE_CHECKING:
         xnli_processors,
         xnli_tasks_num_labels,
     )
+    from .data.data_collator import (
+        DataCollator,
+        DataCollatorForLanguageModeling,
+        DataCollatorForPermutationLanguageModeling,
+        DataCollatorForSeq2Seq,
+        DataCollatorForSOP,
+        DataCollatorForTokenClassification,
+        DataCollatorForWholeWordMask,
+        DataCollatorWithPadding,
+        default_data_collator,
+    )
 
     # Feature Extractor
     from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
@@ -2174,17 +2185,6 @@ if TYPE_CHECKING:
     # Benchmarks
     from .benchmark.benchmark import PyTorchBenchmark
     from .benchmark.benchmark_args import PyTorchBenchmarkArguments
-    from .data.data_collator import (
-        DataCollator,
-        DataCollatorForLanguageModeling,
-        DataCollatorForPermutationLanguageModeling,
-        DataCollatorForSeq2Seq,
-        DataCollatorForSOP,
-        DataCollatorForTokenClassification,
-        DataCollatorForWholeWordMask,
-        DataCollatorWithPadding,
-        default_data_collator,
-    )
     from .data.datasets import (
         GlueDataset,
         GlueDataTrainingArguments,
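Taken together, the four `__init__.py` hunks above move the collator names out of the torch-only branch and into the unconditional `_import_structure` table, which drives transformers' lazy top-level imports. A toy reduction of that idea, for readers unfamiliar with the table; the library's real `_LazyModule` is considerably more elaborate, so treat this only as a sketch of the mechanism:

```python
import importlib
import types

class LazyModuleSketch(types.ModuleType):
    """Toy reduction of the lazy-import machinery that _import_structure feeds:
    attribute access triggers the defining submodule's import on first use."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        module = importlib.import_module(self._attr_to_module[attr])
        return getattr(module, attr)

# Stdlib stand-ins for demonstration; nothing imports until first access:
demo = LazyModuleSketch("demo", {"json": ["dumps"], "math": ["sqrt"]})
print(demo.dumps({"a": 1}), demo.sqrt(2.0))
```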
(File diff suppressed because it is too large.)
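The suppressed diff is very likely the bulk of this change, the collator implementations themselves (the commit message repeatedly references src/transformers/data/data_collator.py). Per the message, the shared `__call__` was moved to a mixin that dispatches on `return_tensors`; a rough sketch of that pattern follows. The method and attribute names are assumptions, since the diff is not shown on this page:

```python
class DataCollatorMixin:
    """Illustrative sketch of the shared __call__ described in the commit
    message; torch_call/tf_call/numpy_call and the return_tensors attribute
    are assumed names, not confirmed by this page."""

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        if return_tensors == "pt":
            return self.torch_call(features)
        if return_tensors == "tf":
            return self.tf_call(features)
        if return_tensors == "np":
            return self.numpy_call(features)
        raise ValueError(f"Framework '{return_tensors}' not recognized!")
```

Each per-framework implementation is then free to build batches with its own tensor types, which is what the framework-specific `*_mask_tokens` members documented above exist for.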
@@ -12,62 +12,6 @@ class PyTorchBenchmarkArguments:
         requires_backends(self, ["torch"])
 
 
-class DataCollator:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class DataCollatorForLanguageModeling:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
-class DataCollatorForPermutationLanguageModeling:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
-class DataCollatorForSeq2Seq:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class DataCollatorForSOP:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class DataCollatorForTokenClassification:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
-class DataCollatorForWholeWordMask:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class DataCollatorWithPadding:
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-def default_data_collator(*args, **kwargs):
-    requires_backends(default_data_collator, ["torch"])
-
-
 class GlueDataset:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
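These deletions retire the dummy-object pattern for the collators: in an environment without torch, `from transformers import DataCollatorWithPadding` used to resolve to a placeholder that failed only at instantiation, with a readable error. Since the collators now work without torch, the placeholders go away. An illustrative reduction of that pattern, assuming the availability check is all the real `requires_backends` helper does here:

```python
def _torch_available():
    try:
        import torch  # noqa: F401
        return True
    except ImportError:
        return False

def requires_backends(obj, backends):
    # Reduced stand-in for transformers' requires_backends helper.
    if "torch" in backends and not _torch_available():
        name = getattr(obj, "__name__", obj.__class__.__name__)
        raise ImportError(f"{name} requires the PyTorch library, but it was not found.")

class DataCollatorWithPadding:  # placeholder shape, as in the deleted dummy file
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```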
@@ -17,20 +17,27 @@ import shutil
 import tempfile
 import unittest
 
-from transformers import BertTokenizer, is_torch_available, set_seed
-from transformers.testing_utils import require_torch
+import numpy as np
+
+from transformers import (
+    BertTokenizer,
+    DataCollatorForLanguageModeling,
+    DataCollatorForPermutationLanguageModeling,
+    DataCollatorForTokenClassification,
+    DataCollatorWithPadding,
+    default_data_collator,
+    is_tf_available,
+    is_torch_available,
+    set_seed,
+)
+from transformers.testing_utils import require_tf, require_torch
 
 
 if is_torch_available():
     import torch
 
-    from transformers import (
-        DataCollatorForLanguageModeling,
-        DataCollatorForPermutationLanguageModeling,
-        DataCollatorForTokenClassification,
-        DataCollatorWithPadding,
-        default_data_collator,
-    )
+if is_tf_available():
+    import tensorflow as tf
 
 
 @require_torch
@@ -61,14 +68,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
 
         # Features can already be tensors
-        features = [{"label": i, "inputs": torch.randint(10, [10])} for i in range(8)]
+        features = [{"label": i, "inputs": np.random.randint(0, 10, [10])} for i in range(8)]
         batch = default_data_collator(features)
         self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
         self.assertEqual(batch["labels"].dtype, torch.long)
         self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
 
         # Labels can already be tensors
-        features = [{"label": torch.tensor(i), "inputs": torch.randint(10, [10])} for i in range(8)]
+        features = [{"label": torch.tensor(i), "inputs": np.random.randint(0, 10, [10])} for i in range(8)]
         batch = default_data_collator(features)
         self.assertEqual(batch["labels"].dtype, torch.long)
         self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
@@ -238,7 +245,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
 
-        example = [torch.randint(5, [5])]
+        example = [np.random.randint(0, 5, [5])]
         with self.assertRaises(ValueError):
             # Expect error due to odd sequence length
             data_collator(example)
@@ -290,3 +297,529 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
         self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
+
+
+@require_tf
+class TFDataCollatorIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.tmpdirname = tempfile.mkdtemp()
+
+        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+        self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
+        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def test_default_with_dict(self):
+        features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="tf")
+        self.assertEqual(batch["labels"].numpy().tolist(), list(range(8)))
+        self.assertEqual(batch["labels"].dtype, tf.int64)
+        self.assertEqual(batch["inputs"].shape.as_list(), [8, 6])
+
+        # With label_ids
+        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="tf")
+        self.assertEqual(batch["labels"].numpy().tolist(), ([[0, 1, 2]] * 8))
+        self.assertEqual(batch["labels"].dtype, tf.int64)
+        self.assertEqual(batch["inputs"].shape.as_list(), [8, 6])
+
+        # Features can already be tensors
+        features = [{"label": i, "inputs": np.random.randint(0, 10, [10])} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="tf")
+        self.assertEqual(batch["labels"].numpy().tolist(), (list(range(8))))
+        self.assertEqual(batch["labels"].dtype, tf.int64)
+        self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
+
+        # Labels can already be tensors
+        features = [{"label": np.array(i), "inputs": np.random.randint(0, 10, [10])} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="tf")
+        self.assertEqual(batch["labels"].dtype, tf.int64)
+        self.assertEqual(batch["labels"].numpy().tolist(), list(range(8)))
+        self.assertEqual(batch["labels"].dtype, tf.int64)
+        self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
+
+    def test_default_classification_and_regression(self):
+        data_collator = default_data_collator
+
+        features = [{"input_ids": [0, 1, 2, 3, 4], "label": i} for i in range(4)]
+        batch = data_collator(features, return_tensors="tf")
+        self.assertEqual(batch["labels"].dtype, tf.int64)
+
+        features = [{"input_ids": [0, 1, 2, 3, 4], "label": float(i)} for i in range(4)]
+        batch = data_collator(features, return_tensors="tf")
+        self.assertEqual(batch["labels"].dtype, tf.float32)
+
+    def test_default_with_no_labels(self):
+        features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="tf")
+        self.assertTrue("labels" not in batch)
+        self.assertEqual(batch["inputs"].shape.as_list(), [8, 6])
+
+        # With label_ids
+        features = [{"label_ids": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="tf")
+        self.assertTrue("labels" not in batch)
+        self.assertEqual(batch["inputs"].shape.as_list(), [8, 6])
+
+    def test_data_collator_with_padding(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
+
+        data_collator = DataCollatorWithPadding(tokenizer, return_tensors="tf")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 6])
+        self.assertEqual(batch["input_ids"][0].numpy().tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10, return_tensors="tf")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+
+        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, [2, 8])
+
+    def test_data_collator_for_token_classification(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
+            {"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
+        ]
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, return_tensors="tf")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 6])
+        self.assertEqual(batch["input_ids"][0].numpy().tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 6])
+        self.assertEqual(batch["labels"][0].numpy().tolist(), [0, 1, 2] + [-100] * 3)
+
+        data_collator = DataCollatorForTokenClassification(
+            tokenizer, padding="max_length", max_length=10, return_tensors="tf"
+        )
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 8])
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1, return_tensors="tf")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 6])
+        self.assertEqual(batch["input_ids"][0].numpy().tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 6])
+        self.assertEqual(batch["labels"][0].numpy().tolist(), [0, 1, 2] + [-1] * 3)
+
+    def _test_no_pad_and_pad(self, no_pad_features, pad_features):
+        tokenizer = BertTokenizer(self.vocab_file)
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf")
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="tf"
+        )
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 16])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 16])
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 16])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 16])
+
+        tokenizer._pad_token = None
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf")
+        with self.assertRaises(ValueError):
+            # Expect error due to padding token missing
+            data_collator(pad_features)
+
+        set_seed(42)  # For reproducibility
+        tokenizer = BertTokenizer(self.vocab_file)
+        data_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf")
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(tf.reduce_any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
+
+        batch = data_collator(pad_features, return_tensors="tf")
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(tf.reduce_any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 16])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 16])
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(tf.reduce_any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
+
+        batch = data_collator(pad_features, return_tensors="tf")
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 16])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 16])
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(tf.reduce_any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))
+
+    def test_data_collator_for_language_modeling(self):
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
+        no_pad_features = [list(range(10)), list(range(10))]
+        pad_features = [list(range(5)), list(range(10))]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
+    def test_plm(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+
+        data_collator = DataCollatorForPermutationLanguageModeling(tokenizer, return_tensors="tf")
+
+        batch = data_collator(pad_features)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["perm_mask"].shape.as_list(), [2, 10, 10])
+        self.assertEqual(batch["target_mapping"].shape.as_list(), [2, 10, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        batch = data_collator(no_pad_features)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+        self.assertEqual(batch["perm_mask"].shape.as_list(), [2, 10, 10])
+        self.assertEqual(batch["target_mapping"].shape.as_list(), [2, 10, 10])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
+        example = [np.random.randint(0, 5, [5])]
+        with self.assertRaises(ValueError):
+            # Expect error due to odd sequence length
+            data_collator(example)
+
+    def test_nsp(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
+            for i in range(2)
+        ]
+        data_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 5])
+        self.assertEqual(batch["token_type_ids"].shape.as_list(), [2, 5])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 5])
+        self.assertEqual(batch["next_sentence_label"].shape.as_list(), [2])
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["token_type_ids"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["next_sentence_label"].shape.as_list(), [2])
+
+    def test_sop(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {
+                "input_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]),
+                "token_type_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]),
+                "sentence_order_label": i,
+            }
+            for i in range(2)
+        ]
+        data_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 5])
+        self.assertEqual(batch["token_type_ids"].shape.as_list(), [2, 5])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 5])
+        self.assertEqual(batch["sentence_order_label"].shape.as_list(), [2])
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["token_type_ids"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["labels"].shape.as_list(), [2, 8])
+        self.assertEqual(batch["sentence_order_label"].shape.as_list(), [2])
+
+
+class NumpyDataCollatorIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+        self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
+        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def test_default_with_dict(self):
+        features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="np")
+        self.assertEqual(batch["labels"].tolist(), list(range(8)))
+        self.assertEqual(batch["labels"].dtype, np.int64)
+        self.assertEqual(batch["inputs"].shape, (8, 6))
+
+        # With label_ids
+        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="np")
+        self.assertEqual(batch["labels"].tolist(), [[0, 1, 2]] * 8)
+        self.assertEqual(batch["labels"].dtype, np.int64)
+        self.assertEqual(batch["inputs"].shape, (8, 6))
+
+        # Features can already be tensors
+        features = [{"label": i, "inputs": np.random.randint(0, 10, [10])} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="np")
+        self.assertEqual(batch["labels"].tolist(), list(range(8)))
+        self.assertEqual(batch["labels"].dtype, np.int64)
+        self.assertEqual(batch["inputs"].shape, (8, 10))
+
+        # Labels can already be tensors
+        features = [{"label": np.array(i), "inputs": np.random.randint(0, 10, [10])} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="np")
+        self.assertEqual(batch["labels"].dtype, np.int64)
+        self.assertEqual(batch["labels"].tolist(), (list(range(8))))
+        self.assertEqual(batch["labels"].dtype, np.int64)
+        self.assertEqual(batch["inputs"].shape, (8, 10))
+
+    def test_default_classification_and_regression(self):
+        data_collator = default_data_collator
+
+        features = [{"input_ids": [0, 1, 2, 3, 4], "label": i} for i in range(4)]
+        batch = data_collator(features, return_tensors="np")
+        self.assertEqual(batch["labels"].dtype, np.int64)
+
+        features = [{"input_ids": [0, 1, 2, 3, 4], "label": float(i)} for i in range(4)]
+        batch = data_collator(features, return_tensors="np")
+        self.assertEqual(batch["labels"].dtype, np.float32)
+
+    def test_default_with_no_labels(self):
+        features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="np")
+        self.assertTrue("labels" not in batch)
+        self.assertEqual(batch["inputs"].shape, (8, 6))
+
+        # With label_ids
+        features = [{"label_ids": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features, return_tensors="np")
+        self.assertTrue("labels" not in batch)
+        self.assertEqual(batch["inputs"].shape, (8, 6))
+
+    def test_data_collator_with_padding(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
+
+        data_collator = DataCollatorWithPadding(tokenizer, return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 6))
+        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10, return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+
+        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 8))
+
+    def test_data_collator_for_token_classification(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
+            {"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
+        ]
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 6))
+        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+        self.assertEqual(batch["labels"].shape, (2, 6))
+        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)
+
+        data_collator = DataCollatorForTokenClassification(
+            tokenizer, padding="max_length", max_length=10, return_tensors="np"
+        )
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8, return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 8))
+        self.assertEqual(batch["labels"].shape, (2, 8))
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1, return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (2, 6))
+        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+        self.assertEqual(batch["labels"].shape, (2, 6))
+        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
+
+    def _test_no_pad_and_pad(self, no_pad_features, pad_features):
+        tokenizer = BertTokenizer(self.vocab_file)
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="np")
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        batch = data_collator(pad_features, return_tensors="np")
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="np"
+        )
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, (2, 16))
+        self.assertEqual(batch["labels"].shape, (2, 16))
+
+        batch = data_collator(pad_features, return_tensors="np")
+        self.assertEqual(batch["input_ids"].shape, (2, 16))
+        self.assertEqual(batch["labels"].shape, (2, 16))
+
+        tokenizer._pad_token = None
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="np")
+        with self.assertRaises(ValueError):
+            # Expect error due to padding token missing
+            data_collator(pad_features)
+
+        set_seed(42)  # For reproducibility
+        tokenizer = BertTokenizer(self.vocab_file)
+        data_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np")
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(np.any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(np.any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np")
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, (2, 16))
+        self.assertEqual(batch["labels"].shape, (2, 16))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(np.any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, (2, 16))
+        self.assertEqual(batch["labels"].shape, (2, 16))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(np.any(masked_tokens))
+        # self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+    def test_data_collator_for_language_modeling(self):
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
+        no_pad_features = [list(range(10)), list(range(10))]
+        pad_features = [list(range(5)), list(range(10))]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
+    def test_plm(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+
+        data_collator = DataCollatorForPermutationLanguageModeling(tokenizer, return_tensors="np")
+
+        batch = data_collator(pad_features)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["perm_mask"].shape, (2, 10, 10))
+        self.assertEqual(batch["target_mapping"].shape, (2, 10, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        batch = data_collator(no_pad_features)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, (2, 10))
+        self.assertEqual(batch["perm_mask"].shape, (2, 10, 10))
+        self.assertEqual(batch["target_mapping"].shape, (2, 10, 10))
+        self.assertEqual(batch["labels"].shape, (2, 10))
+
+        example = [np.random.randint(0, 5, [5])]
+        with self.assertRaises(ValueError):
+            # Expect error due to odd sequence length
+            data_collator(example)
+
+    def test_nsp(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
+            for i in range(2)
+        ]
+        data_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape, (2, 5))
+        self.assertEqual(batch["token_type_ids"].shape, (2, 5))
+        self.assertEqual(batch["labels"].shape, (2, 5))
+        self.assertEqual(batch["next_sentence_label"].shape, (2,))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape, (2, 8))
+        self.assertEqual(batch["token_type_ids"].shape, (2, 8))
+        self.assertEqual(batch["labels"].shape, (2, 8))
+        self.assertEqual(batch["next_sentence_label"].shape, (2,))
+
+    def test_sop(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {
+                "input_ids": np.array([0, 1, 2, 3, 4]),
+                "token_type_ids": np.array([0, 1, 2, 3, 4]),
+                "sentence_order_label": i,
+            }
+            for i in range(2)
+        ]
+        data_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape, (2, 5))
+        self.assertEqual(batch["token_type_ids"].shape, (2, 5))
+        self.assertEqual(batch["labels"].shape, (2, 5))
+        self.assertEqual(batch["sentence_order_label"].shape, (2,))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np")
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape, (2, 8))
+        self.assertEqual(batch["token_type_ids"].shape, (2, 8))
+        self.assertEqual(batch["labels"].shape, (2, 8))
+        self.assertEqual(batch["sentence_order_label"].shape, (2,))
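As a closing worked example, here is the NumPy path exercised end to end, mirroring the assertions of `test_data_collator_for_token_classification` above; the throwaway five-token vocab is the same trick the tests use, so everything in this snippet comes straight from the diff:

```python
import os
import tempfile

from transformers import BertTokenizer, DataCollatorForTokenClassification

# Throwaway five-token vocab, exactly as the tests build it.
tmpdir = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as f:
    f.write("".join(x + "\n" for x in ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]))

tokenizer = BertTokenizer(vocab_file)
features = [
    {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
    {"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
]

collator = DataCollatorForTokenClassification(tokenizer, return_tensors="np")
batch = collator(features)
print(batch["input_ids"].shape)  # (2, 6): the short example is padded with pad_token_id
print(batch["labels"][0])        # [0 1 2 -100 -100 -100]: labels are padded with -100
```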