Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 05:10:06 +06:00
Add strategy to store results in evaluation loop (#30267)
* Add evaluation loop container for interm. results
* Add tests for EvalLoopContainer
* Formatting
* Fix padding_index in test and typo
* Move EvalLoopContainer to pr_utils to avoid additional imports
* Fix `eval_do_concat_batches` arg description
* Fix EvalLoopContainer import
This commit is contained in:
parent 8d6b509611
commit c15aad0939
src/transformers/trainer.py
@@ -82,6 +82,7 @@ from .trainer_callback import (
 )
 from .trainer_pt_utils import (
     DistributedTensorGatherer,
+    EvalLoopContainer,
     IterableDatasetShard,
     LabelSmoother,
     LayerWiseDummyOptimizer,
@@ -3627,20 +3628,14 @@ class Trainer:
             self._past = None

         # Initialize containers
-        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
-        losses_host = None
-        preds_host = None
-        labels_host = None
-        inputs_host = None
+        all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)

-        # losses/preds/labels on CPU (final containers)
-        all_losses = None
-        all_preds = None
-        all_labels = None
-        all_inputs = None
         # Will be useful when we have an iterable dataset so don't know its length.
-
         observed_num_examples = 0
+
         # Main evaluation loop
         for step, inputs in enumerate(dataloader):
             # Update the observed num examples
@@ -3659,56 +3654,33 @@ class Trainer:
             if is_torch_xla_available():
                 xm.mark_step()

-            # Update containers on host
+            # Update containers
             if loss is not None:
                 losses = self.gather_function((loss.repeat(batch_size)))
-                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
-            if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+                all_losses.add(losses)
             if inputs_decode is not None:
                 inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                 inputs_decode = self.gather_function((inputs_decode))
-                inputs_host = (
-                    inputs_decode
-                    if inputs_host is None
-                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
-                )
+                all_inputs.add(inputs_decode)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.gather_function((logits))
-                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
-
+                all_preds.add(logits)
             if labels is not None:
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                 labels = self.gather_function((labels))
-                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+                all_labels.add(labels)

             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
             if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
-                if losses_host is not None:
-                    losses = nested_numpify(losses_host)
-                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-                if preds_host is not None:
-                    logits = nested_numpify(preds_host)
-                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-                if inputs_host is not None:
-                    inputs_decode = nested_numpify(inputs_host)
-                    all_inputs = (
-                        inputs_decode
-                        if all_inputs is None
-                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-                    )
-                if labels_host is not None:
-                    labels = nested_numpify(labels_host)
-                    all_labels = (
-                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
-                    )
-
-                # Set back to None to begin a new accumulation
-                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+                all_losses.to_cpu_and_numpy()
+                all_preds.to_cpu_and_numpy()
+                all_labels.to_cpu_and_numpy()
+                all_inputs.to_cpu_and_numpy()

         # After all calls to `.gather_function`, reset to `gather_for_metrics`:
         self.gather_function = self.accelerator.gather_for_metrics
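Taken together, this hunk keeps the per-batch pad/gather calls and moves only the accumulation into the containers. A minimal sketch of the new flow outside the Trainer (the fake dataloader and single-process setup are stand-ins for the Trainer's internals; assumes a transformers version that includes this commit):

import torch

from transformers.trainer_pt_utils import EvalLoopContainer

eval_accumulation_steps = 2
dataloader = [torch.randn(8, 5) for _ in range(5)]  # stand-in for real eval batches

all_preds = EvalLoopContainer(do_nested_concat=True, padding_index=-100)

for step, logits in enumerate(dataloader):
    all_preds.add(logits)  # accumulate on the current device
    if (step + 1) % eval_accumulation_steps == 0:
        all_preds.to_cpu_and_numpy()  # offload to CPU/numpy, free device memory

preds = all_preds.get_arrays()  # flushes any remainder; np.ndarray of shape (40, 5)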
@@ -3717,20 +3689,10 @@ class Trainer:
             delattr(self, "_past")

         # Gather all remaining tensors and put them back on the CPU
-        if losses_host is not None:
-            losses = nested_numpify(losses_host)
-            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-        if preds_host is not None:
-            logits = nested_numpify(preds_host)
-            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-        if inputs_host is not None:
-            inputs_decode = nested_numpify(inputs_host)
-            all_inputs = (
-                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-            )
-        if labels_host is not None:
-            labels = nested_numpify(labels_host)
-            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+        all_losses = all_losses.get_arrays()
+        all_preds = all_preds.get_arrays()
+        all_labels = all_labels.get_arrays()
+        all_inputs = all_inputs.get_arrays()

         # Number of samples
         if has_length(eval_dataset):
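Note that `get_arrays()` flushes any tensors still on device and returns `None` when nothing was ever added, so the `is not None` handling further down keeps working for models that return no loss or labels. A quick sketch:

from transformers.trainer_pt_utils import EvalLoopContainer

empty = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
print(empty.get_arrays())  # None -- downstream `is not None` checks still short-circuit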
@@ -3761,7 +3723,9 @@
         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
         metrics = denumpify_detensorize(metrics)

-        if all_losses is not None:
+        if isinstance(all_losses, list) and all_losses:
+            metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
+        elif isinstance(all_losses, np.ndarray):
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
         if hasattr(self, "jit_compilation_time"):
             metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
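The two branches mirror the two forms `get_arrays()` can return: one concatenated ndarray when `eval_do_concat_batches=True`, or a list of per-batch arrays when it is `False`. An illustration with invented loss values:

import numpy as np

concat_losses = np.array([0.5, 0.7, 0.6, 0.8])                 # eval_do_concat_batches=True
listed_losses = [np.array([0.5, 0.7]), np.array([0.6, 0.8])]   # eval_do_concat_batches=False

# Both branches produce the same eval_loss:
assert np.concatenate(listed_losses).mean().item() == concat_losses.mean().item()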
@@ -4204,6 +4168,7 @@ class Trainer:
         logger.info(f"***** Running {description} *****")
         logger.info(f"  Num examples = {num_examples}")
         logger.info(f"  Batch size = {batch_size}")
+
         losses_host: torch.Tensor = None
         preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
         labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
src/transformers/trainer_pt_utils.py
@@ -299,6 +299,58 @@ class DistributedSamplerWithLoop(DistributedSampler):
         return iter(indices)


+class EvalLoopContainer:
+    """
+    Container to store intermediate results of evaluation loop
+
+    Args:
+        do_nested_concat (`bool`, *optional*, defaults to `True`):
+            If set to `True`, each iteration will recursively concatenate a new object containing tensors to
+            the existing stored tensors, provided that the structure of the existing object and the new one
+            are identical. If set to `False`, all newly added tensors will be stored in a list.
+        padding_index (`int`, *optional*, defaults to -100):
+            Value used to pad tensors of different shapes when `do_nested_concat=True`.
+    """
+
+    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
+        self.do_nested_concat = do_nested_concat
+        self.padding_index = padding_index
+        self.tensors = None
+        self.arrays = None
+
+    def add(self, tensors) -> None:
+        """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively."""
+        if self.tensors is None:
+            self.tensors = tensors if self.do_nested_concat else [tensors]
+        elif self.do_nested_concat:
+            self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
+        else:
+            self.tensors.append(tensors)
+
+    def to_cpu_and_numpy(self) -> None:
+        """Move tensors in stored objects to CPU and convert them to numpy arrays."""
+
+        # Check if we have something to add, if not just return
+        if self.tensors is None:
+            return
+
+        new_arrays = nested_numpify(self.tensors)
+        if self.arrays is None:
+            self.arrays = new_arrays
+        elif self.do_nested_concat:
+            self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
+        else:
+            self.arrays.extend(new_arrays)
+
+        # reset device tensors after adding to cpu
+        self.tensors = None
+
+    def get_arrays(self):
+        """Returns the numpified and moved to CPU stored objects."""
+        self.to_cpu_and_numpy()
+        return self.arrays
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
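A quick usage sketch of the class above, showing the cross-batch padding described in the docstring (the dict payload and shapes are invented for illustration; nested dict support comes from `nested_concat`/`nested_numpify`):

import torch

from transformers.trainer_pt_utils import EvalLoopContainer

container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
container.add({"logits": torch.ones(2, 3), "labels": torch.zeros(2)})
container.add({"logits": torch.ones(2, 5), "labels": torch.zeros(2)})

out = container.get_arrays()
# Mismatched trailing dims are padded with padding_index before concatenation:
print(out["logits"].shape)  # (4, 5) -- the first batch's rows end in -100 padding
print(out["labels"].shape)  # (4,)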
src/transformers/training_args.py
@@ -670,6 +670,9 @@ class TrainingArguments:
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
+        eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
+            Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
+            will instead store them as lists, with each batch kept separate.
         auto_find_batch_size (`bool`, *optional*, defaults to `False`)
             Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
             CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -1289,6 +1292,12 @@
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
    )
+    eval_do_concat_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate."
+        },
+    )
     # Deprecated arguments
     fp16_backend: str = field(
         default="auto",
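A sketch of the user-facing flag added here (the `output_dir` is a placeholder):

from transformers import TrainingArguments

# With the default (True), predictions/labels arrive in compute_metrics as
# concatenated arrays; with False they stay as lists of per-batch arrays.
args = TrainingArguments(
    output_dir="out",  # placeholder
    eval_do_concat_batches=False,
)
print(args.eval_do_concat_batches)  # False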
tests/trainer/test_trainer_utils.py
@@ -35,6 +35,7 @@ if is_torch_available():
         DistributedLengthGroupedSampler,
         DistributedSamplerWithLoop,
         DistributedTensorGatherer,
+        EvalLoopContainer,
         IterableDatasetShard,
         LabelSmoother,
         LengthGroupedSampler,
@@ -497,3 +498,92 @@ class TrainerUtilsTest(unittest.TestCase):
         remove_columns_collator(data_batch)
         self.assertEqual(logger.called, 1)
         self.assertIn("col3", logger.last_msg)
+
+    def test_eval_loop_container(self):
+        batch_1 = [
+            torch.ones([8, 5]),
+            {"loss": torch.tensor(1.0)},
+            (torch.ones([8, 2, 3]), torch.ones([8, 2])),
+        ]
+        batch_2 = [
+            torch.ones([4, 5]),
+            {"loss": torch.tensor(2.0)},
+            (torch.ones([4, 2, 3]), torch.ones([4, 6])),
+        ]
+
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+        concat_container.add(batch_1)
+        concat_container.add(batch_2)
+        concat_container.to_cpu_and_numpy()
+        arrays = concat_container.get_arrays()
+
+        # Test two nested batches concatenation
+        self.assertIsInstance(arrays, list)
+        self.assertEqual(len(arrays), 3)
+        self.assertIsInstance(arrays[0], np.ndarray)
+        self.assertEqual(arrays[0].shape, (12, 5))
+        self.assertIsInstance(arrays[1], dict)
+        self.assertIsInstance(arrays[1]["loss"], np.ndarray)
+        self.assertEqual(arrays[1]["loss"].shape, (2,))
+        self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0, 2.0])))
+        self.assertIsInstance(arrays[2], tuple)
+        self.assertEqual(len(arrays[2]), 2)
+        self.assertEqual(arrays[2][0].shape, (12, 2, 3))
+        self.assertEqual(arrays[2][1].shape, (12, 6))
+        # check that first batch padded with padding index -100 after concatenation
+        self.assertEqual(arrays[2][1][0][2], -100)
+
+        # Test two batches with no concatenation
+        list_container = EvalLoopContainer(do_nested_concat=False)
+        list_container.add(batch_1)
+        list_container.add(batch_2)
+        list_container.to_cpu_and_numpy()
+        arrays = list_container.get_arrays()
+
+        self.assertEqual(len(arrays), 2)
+        self.assertIsInstance(arrays, list)
+        np_batch_1, np_batch_2 = arrays
+
+        self.assertIsInstance(np_batch_1, list)
+        self.assertEqual(len(np_batch_1), 3)
+        self.assertIsInstance(np_batch_1[0], np.ndarray)
+        self.assertIsInstance(np_batch_1[1], dict)
+        self.assertIsInstance(np_batch_1[2], tuple)
+        self.assertEqual(np_batch_1[0].shape, (8, 5))
+        self.assertEqual(np_batch_1[1]["loss"].shape, ())
+        self.assertEqual(np_batch_1[2][0].shape, (8, 2, 3))
+        self.assertEqual(np_batch_1[2][1].shape, (8, 2))
+
+        self.assertIsInstance(np_batch_2, list)
+        self.assertEqual(len(np_batch_2), 3)
+        self.assertIsInstance(np_batch_2[0], np.ndarray)
+        self.assertIsInstance(np_batch_2[1], dict)
+        self.assertIsInstance(np_batch_2[2], tuple)
+        self.assertEqual(np_batch_2[0].shape, (4, 5))
+        self.assertEqual(np_batch_2[1]["loss"].shape, ())
+        self.assertEqual(np_batch_2[2][0].shape, (4, 2, 3))
+        self.assertEqual(np_batch_2[2][1].shape, (4, 6))
+
+        # Test no batches
+        none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
+        self.assertIsNone(none_arr)
+
+        none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
+        self.assertIsNone(none_arr)
+
+        # Test one batch
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+        concat_container.add(batch_1)
+        arrays = concat_container.get_arrays()
+        self.assertIsInstance(arrays, list)
+        self.assertEqual(len(arrays), 3)
+        self.assertIsInstance(arrays[0], np.ndarray)
+        self.assertEqual(arrays[0].shape, (8, 5))
+        self.assertIsInstance(arrays[1], dict)
+        self.assertIsInstance(arrays[1]["loss"], np.ndarray)
+        self.assertEqual(arrays[1]["loss"].shape, ())
+        self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0])))
+        self.assertIsInstance(arrays[2], tuple)
+        self.assertEqual(len(arrays[2]), 2)
+        self.assertEqual(arrays[2][0].shape, (8, 2, 3))
+        self.assertEqual(arrays[2][1].shape, (8, 2))