mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Make default_data_collator more flexible and deprecate old behavior (#5060)
* Make default_data_collator more flexible * Accept tensors for all features * Document code * Refactor * Formatting
This commit is contained in:
parent
5e06963394
commit
20fa828984
@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
|||||||
# have the same attributes.
|
# have the same attributes.
|
||||||
# So we will look at the first element as a proxy for what attributes exist
|
# So we will look at the first element as a proxy for what attributes exist
|
||||||
# on the whole batch.
|
# on the whole batch.
|
||||||
|
if not isinstance(features[0], dict):
|
||||||
|
features = [vars(f) for f in features]
|
||||||
|
|
||||||
first = features[0]
|
first = features[0]
|
||||||
|
batch = {}
|
||||||
|
|
||||||
# Special handling for labels.
|
# Special handling for labels.
|
||||||
# Ensure that tensor is created with the correct type
|
# Ensure that tensor is created with the correct type
|
||||||
# (it should be automatically the case, but let's make sure of it.)
|
# (it should be automatically the case, but let's make sure of it.)
|
||||||
if hasattr(first, "label") and first.label is not None:
|
if "label" in first:
|
||||||
if type(first.label) is int:
|
dtype = torch.long if type(first["label"]) is int else torch.float
|
||||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
||||||
|
elif "label_ids" in first:
|
||||||
|
if isinstance(first["label_ids"], torch.Tensor):
|
||||||
|
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
||||||
else:
|
else:
|
||||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
|
||||||
batch = {"labels": labels}
|
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
||||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
|
||||||
if type(first.label_ids[0]) is int:
|
|
||||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
|
||||||
else:
|
|
||||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
|
||||||
batch = {"labels": labels}
|
|
||||||
else:
|
|
||||||
batch = {}
|
|
||||||
|
|
||||||
# Handling of all other possible attributes.
|
# Handling of all other possible keys.
|
||||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||||
for k, v in vars(first).items():
|
for k, v in first.items():
|
||||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
if isinstance(v, torch.Tensor):
|
||||||
|
batch[k] = torch.stack([f[k] for f in features])
|
||||||
|
else:
|
||||||
|
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
@ -205,6 +206,15 @@ class Trainer:
|
|||||||
# Set an xla_device flag on the model's config.
|
# Set an xla_device flag on the model's config.
|
||||||
# We'll find a more elegant and not need to do this in the future.
|
# We'll find a more elegant and not need to do this in the future.
|
||||||
self.model.config.xla_device = True
|
self.model.config.xla_device = True
|
||||||
|
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
||||||
|
self.data_collator = self.data_collator.collate_batch
|
||||||
|
warnings.warn(
|
||||||
|
(
|
||||||
|
"The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
|
||||||
|
+ "with a `collate_batch` are deprecated and won't be supported in a future version."
|
||||||
|
),
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
if self.train_dataset is None:
|
if self.train_dataset is None:
|
||||||
|
@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class DataCollatorIntegrationTest(unittest.TestCase):
|
class DataCollatorIntegrationTest(unittest.TestCase):
|
||||||
|
def test_default_with_dict(self):
|
||||||
|
features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||||
|
batch = default_data_collator(features)
|
||||||
|
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||||
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||||
|
|
||||||
|
# With label_ids
|
||||||
|
features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||||
|
batch = default_data_collator(features)
|
||||||
|
self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
|
||||||
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||||
|
|
||||||
|
# Features can already be tensors
|
||||||
|
features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
|
||||||
|
batch = default_data_collator(features)
|
||||||
|
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||||
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
||||||
|
|
||||||
def test_default_classification(self):
|
def test_default_classification(self):
|
||||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
Loading…
Reference in New Issue
Block a user