Add strategy to store results in evaluation loop (#30267)

* Add evaluation loop container for intermediate results

* Add tests for EvalLoopContainer

* Formatting

* Fix padding_index in test and typo

* Move EvalLoopContainer to pt_utils to avoid additional imports

* Fix `eval_do_concat_batches` arg description

* Fix EvalLoopContainer import
This commit is contained in:
Pavel Iakubovskii 2024-04-17 12:42:27 +01:00 committed by GitHub
parent 8d6b509611
commit c15aad0939
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 175 additions and 59 deletions

View File

@ -82,6 +82,7 @@ from .trainer_callback import (
) )
from .trainer_pt_utils import ( from .trainer_pt_utils import (
DistributedTensorGatherer, DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard, IterableDatasetShard,
LabelSmoother, LabelSmoother,
LayerWiseDummyOptimizer, LayerWiseDummyOptimizer,
@ -3627,20 +3628,14 @@ class Trainer:
self._past = None self._past = None
# Initialize containers # Initialize containers
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
losses_host = None all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
preds_host = None all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
labels_host = None all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
inputs_host = None
# losses/preds/labels on CPU (final containers)
all_losses = None
all_preds = None
all_labels = None
all_inputs = None
# Will be useful when we have an iterable dataset so don't know its length. # Will be useful when we have an iterable dataset so don't know its length.
observed_num_examples = 0 observed_num_examples = 0
# Main evaluation loop # Main evaluation loop
for step, inputs in enumerate(dataloader): for step, inputs in enumerate(dataloader):
# Update the observed num examples # Update the observed num examples
@ -3659,56 +3654,33 @@ class Trainer:
if is_torch_xla_available(): if is_torch_xla_available():
xm.mark_step() xm.mark_step()
# Update containers on host # Update containers
if loss is not None: if loss is not None:
losses = self.gather_function((loss.repeat(batch_size))) losses = self.gather_function((loss.repeat(batch_size)))
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) all_losses.add(losses)
if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
if inputs_decode is not None: if inputs_decode is not None:
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
inputs_decode = self.gather_function((inputs_decode)) inputs_decode = self.gather_function((inputs_decode))
inputs_host = ( all_inputs.add(inputs_decode)
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
if logits is not None: if logits is not None:
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None: if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.preprocess_logits_for_metrics(logits, labels)
logits = self.gather_function((logits)) logits = self.gather_function((logits))
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) all_preds.add(logits)
if labels is not None: if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
labels = self.gather_function((labels)) labels = self.gather_function((labels))
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) all_labels.add(labels)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps. # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None: all_losses.to_cpu_and_numpy()
losses = nested_numpify(losses_host) all_preds.to_cpu_and_numpy()
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) all_labels.to_cpu_and_numpy()
if preds_host is not None: all_inputs.to_cpu_and_numpy()
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode
if all_inputs is None
else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = (
labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
)
# Set back to None to begin a new accumulation
losses_host, preds_host, inputs_host, labels_host = None, None, None, None
# After all calls to `.gather_function`, reset to `gather_for_metrics`: # After all calls to `.gather_function`, reset to `gather_for_metrics`:
self.gather_function = self.accelerator.gather_for_metrics self.gather_function = self.accelerator.gather_for_metrics
@ -3717,20 +3689,10 @@ class Trainer:
delattr(self, "_past") delattr(self, "_past")
# Gather all remaining tensors and put them back on the CPU # Gather all remaining tensors and put them back on the CPU
if losses_host is not None: all_losses = all_losses.get_arrays()
losses = nested_numpify(losses_host) all_preds = all_preds.get_arrays()
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) all_labels = all_labels.get_arrays()
if preds_host is not None: all_inputs = all_inputs.get_arrays()
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
# Number of samples # Number of samples
if has_length(eval_dataset): if has_length(eval_dataset):
@ -3761,7 +3723,9 @@ class Trainer:
# To be JSON-serializable, we need to remove numpy types or zero-d tensors # To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics) metrics = denumpify_detensorize(metrics)
if all_losses is not None: if isinstance(all_losses, list) and all_losses:
metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
elif isinstance(all_losses, np.ndarray):
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "jit_compilation_time"): if hasattr(self, "jit_compilation_time"):
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
@ -4204,6 +4168,7 @@ class Trainer:
logger.info(f"***** Running {description} *****") logger.info(f"***** Running {description} *****")
logger.info(f" Num examples = {num_examples}") logger.info(f" Num examples = {num_examples}")
logger.info(f" Batch size = {batch_size}") logger.info(f" Batch size = {batch_size}")
losses_host: torch.Tensor = None losses_host: torch.Tensor = None
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

View File

@ -299,6 +299,58 @@ class DistributedSamplerWithLoop(DistributedSampler):
return iter(indices) return iter(indices)
class EvalLoopContainer:
    """
    Accumulator for the intermediate outputs produced by an evaluation loop.

    Objects passed to `add` are kept in `self.tensors` until `to_cpu_and_numpy` is called, at which point they are
    converted with `nested_numpify` and merged into `self.arrays`.

    Args:
        do_nested_concat (`bool`, *optional*, defaults to `True`):
            When `True`, every added object is merged into the stored one with `nested_concat`. When `False`,
            added objects are kept separately, one list entry per call to `add`.
        padding_index (`int`, *optional*, defaults to -100):
            Value forwarded to `nested_concat` as `padding_index` when `do_nested_concat=True`.
    """

    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
        # Pending objects that have not been numpified yet.
        self.tensors = None
        # Objects already converted by `to_cpu_and_numpy`.
        self.arrays = None
        self.do_nested_concat = do_nested_concat
        self.padding_index = padding_index

    def add(self, tensors) -> None:
        """Store `tensors`: merged with `nested_concat` if `do_nested_concat=True`, appended to a list otherwise."""
        if not self.do_nested_concat:
            # List mode: one entry per call.
            if self.tensors is None:
                self.tensors = [tensors]
            else:
                self.tensors.append(tensors)
            return

        # Concat mode: the first object is stored as is, later ones are merged into it.
        if self.tensors is None:
            self.tensors = tensors
        else:
            self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)

    def to_cpu_and_numpy(self) -> None:
        """Convert the pending objects with `nested_numpify` and merge them into the stored arrays."""
        pending = self.tensors
        # Nothing was added since the last flush.
        if pending is None:
            return

        converted = nested_numpify(pending)
        if self.arrays is None:
            self.arrays = converted
        elif not self.do_nested_concat:
            self.arrays.extend(converted)
        else:
            self.arrays = nested_concat(self.arrays, converted, padding_index=self.padding_index)

        # Start a fresh accumulation once the pending objects have been merged.
        self.tensors = None

    def get_arrays(self):
        """Flush any pending objects and return the stored arrays (`None` if nothing was ever added)."""
        self.to_cpu_and_numpy()
        return self.arrays
class SequentialDistributedSampler(Sampler): class SequentialDistributedSampler(Sampler):
""" """
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end. Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.

View File

@ -670,6 +670,9 @@ class TrainingArguments:
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class. that need inputs, predictions and references for scoring calculation in Metric class.
eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
will instead store them as lists, with each batch kept separate.
auto_find_batch_size (`bool`, *optional*, defaults to `False`) auto_find_batch_size (`bool`, *optional*, defaults to `False`)
Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@ -1289,6 +1292,12 @@ class TrainingArguments:
include_inputs_for_metrics: bool = field( include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."} default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
) )
# Passed by the Trainer evaluation loop as the first argument of each `EvalLoopContainer` it creates.
eval_do_concat_batches: bool = field(
default=True,
metadata={
"help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate."
},
)
# Deprecated arguments # Deprecated arguments
fp16_backend: str = field( fp16_backend: str = field(
default="auto", default="auto",

View File

@ -35,6 +35,7 @@ if is_torch_available():
DistributedLengthGroupedSampler, DistributedLengthGroupedSampler,
DistributedSamplerWithLoop, DistributedSamplerWithLoop,
DistributedTensorGatherer, DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard, IterableDatasetShard,
LabelSmoother, LabelSmoother,
LengthGroupedSampler, LengthGroupedSampler,
@ -497,3 +498,92 @@ class TrainerUtilsTest(unittest.TestCase):
remove_columns_collator(data_batch) remove_columns_collator(data_batch)
self.assertEqual(logger.called, 1) self.assertEqual(logger.called, 1)
self.assertIn("col3", logger.last_msg) self.assertIn("col3", logger.last_msg)
def test_eval_loop_container(self):
    """Exercise `EvalLoopContainer` in concat mode, in list mode, with no batches and with a single batch."""
    first_batch = [
        torch.ones([8, 5]),
        {"loss": torch.tensor(1.0)},
        (torch.ones([8, 2, 3]), torch.ones([8, 2])),
    ]
    second_batch = [
        torch.ones([4, 5]),
        {"loss": torch.tensor(2.0)},
        (torch.ones([4, 2, 3]), torch.ones([4, 6])),
    ]

    # Two nested batches, concatenated
    container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
    for batch in (first_batch, second_batch):
        container.add(batch)
    container.to_cpu_and_numpy()
    merged = container.get_arrays()

    self.assertIsInstance(merged, list)
    self.assertEqual(len(merged), 3)
    features, loss_dict, pair = merged
    self.assertIsInstance(features, np.ndarray)
    self.assertEqual(features.shape, (12, 5))
    self.assertIsInstance(loss_dict, dict)
    self.assertIsInstance(loss_dict["loss"], np.ndarray)
    self.assertEqual(loss_dict["loss"].shape, (2,))
    self.assertTrue(np.allclose(loss_dict["loss"], np.array([1.0, 2.0])))
    self.assertIsInstance(pair, tuple)
    self.assertEqual(len(pair), 2)
    self.assertEqual(pair[0].shape, (12, 2, 3))
    self.assertEqual(pair[1].shape, (12, 6))
    # the narrower first batch must have been padded with the padding index -100
    self.assertEqual(pair[1][0][2], -100)

    # Two batches, kept separate
    container = EvalLoopContainer(do_nested_concat=False)
    for batch in (first_batch, second_batch):
        container.add(batch)
    container.to_cpu_and_numpy()
    per_batch = container.get_arrays()

    self.assertIsInstance(per_batch, list)
    self.assertEqual(len(per_batch), 2)
    expected_shapes = (
        ((8, 5), (8, 2, 3), (8, 2)),
        ((4, 5), (4, 2, 3), (4, 6)),
    )
    for np_batch, (features_shape, left_shape, right_shape) in zip(per_batch, expected_shapes):
        self.assertIsInstance(np_batch, list)
        self.assertEqual(len(np_batch), 3)
        self.assertIsInstance(np_batch[0], np.ndarray)
        self.assertIsInstance(np_batch[1], dict)
        self.assertIsInstance(np_batch[2], tuple)
        self.assertEqual(np_batch[0].shape, features_shape)
        self.assertEqual(np_batch[1]["loss"].shape, ())
        self.assertEqual(np_batch[2][0].shape, left_shape)
        self.assertEqual(np_batch[2][1].shape, right_shape)

    # No batches at all
    empty_concat = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
    self.assertIsNone(empty_concat.get_arrays())
    empty_list = EvalLoopContainer(do_nested_concat=False)
    self.assertIsNone(empty_list.get_arrays())

    # A single batch, concat mode
    container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
    container.add(first_batch)
    single = container.get_arrays()

    self.assertIsInstance(single, list)
    self.assertEqual(len(single), 3)
    features, loss_dict, pair = single
    self.assertIsInstance(features, np.ndarray)
    self.assertEqual(features.shape, (8, 5))
    self.assertIsInstance(loss_dict, dict)
    self.assertIsInstance(loss_dict["loss"], np.ndarray)
    self.assertEqual(loss_dict["loss"].shape, ())
    self.assertTrue(np.allclose(loss_dict["loss"], np.array([1.0])))
    self.assertIsInstance(pair, tuple)
    self.assertEqual(len(pair), 2)
    self.assertEqual(pair[0].shape, (8, 2, 3))
    self.assertEqual(pair[1].shape, (8, 2))