Make default_data_collator more flexible and deprecate old behavior (#5060)

* Make default_data_collator more flexible

* Accept tensors for all features

* Document code

* Refactor

* Formatting
Sylvain Gugger 2020-06-17 15:24:51 -04:00 committed by GitHub
parent 5e06963394
commit 20fa828984
3 changed files with 50 additions and 16 deletions


@@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
     # have the same attributes.
     # So we will look at the first element as a proxy for what attributes exist
     # on the whole batch.
+    if not isinstance(features[0], dict):
+        features = [vars(f) for f in features]
+
     first = features[0]
+    batch = {}
 
     # Special handling for labels.
     # Ensure that tensor is created with the correct type
     # (it should be automatically the case, but let's make sure of it.)
-    if hasattr(first, "label") and first.label is not None:
-        if type(first.label) is int:
-            labels = torch.tensor([f.label for f in features], dtype=torch.long)
-        else:
-            labels = torch.tensor([f.label for f in features], dtype=torch.float)
-        batch = {"labels": labels}
-    elif hasattr(first, "label_ids") and first.label_ids is not None:
-        if type(first.label_ids[0]) is int:
-            labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
-        else:
-            labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
-        batch = {"labels": labels}
-    else:
-        batch = {}
+    if "label" in first:
+        dtype = torch.long if type(first["label"]) is int else torch.float
+        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
+    elif "label_ids" in first:
+        if isinstance(first["label_ids"], torch.Tensor):
+            batch["labels"] = torch.stack([f["label_ids"] for f in features])
+        else:
+            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
+            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
 
-    # Handling of all other possible attributes.
+    # Handling of all other possible keys.
     # Again, we will use the first element to figure out which key/values are not None for this model.
-    for k, v in vars(first).items():
+    for k, v in first.items():
         if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
-            batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
+            if isinstance(v, torch.Tensor):
+                batch[k] = torch.stack([f[k] for f in features])
+            else:
+                batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
 
     return batch
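
A rough usage sketch of what the reworked collator accepts (not part of the commit; the "inputs" key and toy values are invented for illustration): features may now be plain dicts, values may already be tensors, and `label`/`label_ids` end up in the batch under `labels`.

```python
import torch

from transformers.data.data_collator import default_data_collator  # module shown in this diff

# Plain dicts work; no InputFeatures-style object is required.
# "inputs" is an arbitrary illustrative key, and its values are already tensors.
features = [{"label": i % 2, "inputs": torch.tensor([0, 1, 2, 3])} for i in range(4)]
batch = default_data_collator(features)

print(batch["labels"])        # tensor([0, 1, 0, 1]), dtype torch.long
print(batch["inputs"].shape)  # torch.Size([4, 4]) -- tensor values are stacked
```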


@@ -4,6 +4,7 @@ import os
 import random
 import re
 import shutil
+import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple
@@ -205,6 +206,15 @@ class Trainer:
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant and not need to do this in the future.
             self.model.config.xla_device = True
+        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+            self.data_collator = self.data_collator.collate_batch
+            warnings.warn(
+                (
+                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
+                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
+                ),
+                FutureWarning,
+            )
 
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:


@@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
 
 @require_torch
 class DataCollatorIntegrationTest(unittest.TestCase):
+    def test_default_with_dict(self):
+        features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # With label_ids
+        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # Features can already be tensors
+        features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
+
     def test_default_classification(self):
         MODEL_ID = "bert-base-cased-finetuned-mrpc"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)