Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

Cleanup Trainer and expose customization points (#5982)

* Clean up Trainer and expose customization points
* Formatting
* eval_step -> prediction_step

This commit is contained in:
parent 76f52324b1
commit e168488a74
@@ -202,7 +202,7 @@ class Trainer:
                 "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
             )
         if is_wandb_available():
-            self._setup_wandb()
+            self.setup_wandb()
         else:
             logger.info(
                 "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
@@ -226,23 +226,32 @@ class Trainer:
                 FutureWarning,
             )
 
-    def get_train_dataloader(self) -> DataLoader:
-        """
-        Returns the training :class:`~torch.utils.data.DataLoader`.
-        """
+    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
-            train_sampler = None
-        elif self.train_dataset is None:
-            raise ValueError("Trainer: training requires a train_dataset.")
+            return None
         elif is_torch_tpu_available():
-            train_sampler = get_tpu_sampler(self.train_dataset)
+            return get_tpu_sampler(self.train_dataset)
         else:
-            train_sampler = (
+            return (
                 RandomSampler(self.train_dataset)
                 if self.args.local_rank == -1
                 else DistributedSampler(self.train_dataset)
             )
-        data_loader = DataLoader(
+
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training :class:`~torch.utils.data.DataLoader`.
+
+        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
+        (adapted to distributed training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+        train_sampler = self._get_train_sampler()
+
+        return DataLoader(
             self.train_dataset,
             batch_size=self.args.train_batch_size,
             sampler=train_sampler,
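The hunk above splits sampler selection out of get_train_dataloader into the new _get_train_sampler hook, so a subclass can change how training examples are drawn without re-implementing dataloader construction. A minimal sketch of such an override, not part of the commit (WeightedSamplerTrainer and example_weights are made-up names):

    from typing import Optional

    import torch
    from torch.utils.data.sampler import Sampler, WeightedRandomSampler

    from transformers import Trainer


    class WeightedSamplerTrainer(Trainer):  # hypothetical subclass
        def __init__(self, *args, example_weights: Optional[torch.Tensor] = None, **kwargs):
            super().__init__(*args, **kwargs)
            self.example_weights = example_weights

        def _get_train_sampler(self) -> Optional[Sampler]:
            # Fall back to the default behavior when no weights are given.
            if self.example_weights is None:
                return super()._get_train_sampler()
            # Draw len(train_dataset) samples per epoch, proportionally to the given weights.
            return WeightedRandomSampler(self.example_weights, num_samples=len(self.train_dataset))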
@@ -250,69 +259,66 @@ class Trainer:
             drop_last=self.args.dataloader_drop_last,
         )
 
-        return data_loader
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            return None
+        elif is_torch_tpu_available():
+            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif self.args.local_rank != -1:
+            return SequentialDistributedSampler(eval_dataset)
+        else:
+            return SequentialSampler(eval_dataset)
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         """
         Returns the evaluation :class:`~torch.utils.data.DataLoader`.
 
+        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
+        sampler (adapted to distributed training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+
         Args:
-            eval_dataset (:obj:`Dataset`, `optional`):
-                If provided, will override `self.eval_dataset`.
+            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+                If provided, will override :obj:`self.eval_dataset`.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
 
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+        eval_sampler = self._get_eval_sampler(eval_dataset)
 
-        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
-            sampler = None
-        elif is_torch_tpu_available():
-            sampler = SequentialDistributedSampler(
-                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
-            )
-        elif self.args.local_rank != -1:
-            sampler = SequentialDistributedSampler(eval_dataset)
-        else:
-            sampler = SequentialSampler(eval_dataset)
-
-        data_loader = DataLoader(
+        return DataLoader(
             eval_dataset,
-            sampler=sampler,
+            sampler=eval_sampler,
             batch_size=self.args.eval_batch_size,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
 
-        return data_loader
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
         Returns the test :class:`~torch.utils.data.DataLoader`.
 
-        Args:
-            test_dataset (obj:`Dataset`): The test dataset to use.
-        """
-        # We use the same batch_size as for eval.
-        if isinstance(self.test_dataset, torch.utils.data.IterableDataset):
-            sampler = None
-        elif is_torch_tpu_available():
-            sampler = SequentialDistributedSampler(
-                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
-            )
-        elif self.args.local_rank != -1:
-            sampler = SequentialDistributedSampler(test_dataset)
-        else:
-            sampler = SequentialSampler(test_dataset)
-
-        data_loader = DataLoader(
+        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
+        sampler (adapted to distributed training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+                The test dataset to use.
+        """
+        test_sampler = self._get_eval_sampler(test_dataset)
+
+        # We use the same batch_size as for eval.
+        return DataLoader(
             test_dataset,
-            sampler=sampler,
+            sampler=test_sampler,
             batch_size=self.args.eval_batch_size,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
-        return data_loader
 
     def get_optimizers(
         self, num_training_steps: int
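get_eval_dataloader and get_test_dataloader now share the new _get_eval_sampler hook, so a single override changes sampling for both evaluation and prediction. A minimal sketch under that assumption, not part of the commit (SubsetEvalTrainer and max_eval_samples are made-up names):

    from typing import Optional

    import torch
    from torch.utils.data import Dataset
    from torch.utils.data.sampler import Sampler, SubsetRandomSampler

    from transformers import Trainer


    class SubsetEvalTrainer(Trainer):  # hypothetical subclass
        def __init__(self, *args, max_eval_samples: int = 1000, **kwargs):
            super().__init__(*args, **kwargs)
            self.max_eval_samples = max_eval_samples

        def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[Sampler]:
            if len(eval_dataset) <= self.max_eval_samples:
                return super()._get_eval_sampler(eval_dataset)
            # Evaluate on a random subset to keep evaluation cheap on large datasets.
            indices = torch.randperm(len(eval_dataset))[: self.max_eval_samples]
            return SubsetRandomSampler(indices.tolist())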
@@ -321,7 +327,7 @@ class Trainer:
         Setup the optimizer and the learning rate scheduler.
 
         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        Trainer's init through :obj:`optimizers`, or override this method in a subclass.
+        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizers is not None:
             return self.optimizers
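As the docstring says, get_optimizers is only a default: a custom optimizer/scheduler pair passed through the optimizers init argument is returned as-is. A hedged sketch of that path (model, args, train_dataset and num_training_steps are assumed to be defined elsewhere):

    import torch

    from transformers import Trainer, get_linear_schedule_with_warmup

    # Build the pair yourself instead of overriding get_optimizers.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        optimizers=(optimizer, scheduler),
    )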
@@ -343,12 +349,12 @@ class Trainer:
         )
         return optimizer, scheduler
 
-    def _setup_wandb(self):
+    def setup_wandb(self):
         """
         Setup the optional Weights & Biases (`wandb`) integration.
 
-        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
-        You can also override the following environment variables:
+        One can subclass and override this method to customize the setup if needed. Find more information
+        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
 
         Environment:
             WANDB_WATCH:
@@ -359,6 +365,13 @@ class Trainer:
             WANDB_DISABLED:
                 (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
         """
+        if hasattr(self, "_setup_wandb"):
+            warnings.warn(
+                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
+                FutureWarning,
+            )
+            return self._setup_wandb()
+
         if self.is_world_master():
             logger.info(
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
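setup_wandb is now public and overridable; the hasattr check only fires when a subclass still defines the old _setup_wandb, keeping it working while emitting a FutureWarning. A minimal sketch of a custom setup, not part of the commit (MyTrainer and the project/tag values are made-up):

    import wandb

    from transformers import Trainer


    class MyTrainer(Trainer):  # hypothetical subclass
        def setup_wandb(self):
            if self.is_world_master():
                wandb.init(
                    project="my-experiments",  # hypothetical project name
                    config=vars(self.args),
                    tags=["trainer-refactor-demo"],  # hypothetical tag
                )
                wandb.watch(self.model, log="gradients", log_freq=self.args.logging_steps)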
@@ -372,7 +385,7 @@ class Trainer:
 
     def num_examples(self, dataloader: DataLoader) -> int:
         """
-        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its Dataset.
+        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
         """
         return len(dataloader.dataset)
 
@@ -500,7 +513,7 @@ class Trainer:
                     steps_trained_in_current_epoch -= 1
                     continue
 
-                tr_loss += self._training_step(model, inputs, optimizer)
+                tr_loss += self.training_step(model, inputs, optimizer)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -535,7 +548,7 @@ class Trainer:
                         )
                         logging_loss = tr_loss
 
-                        self._log(logs)
+                        self.log(logs)
 
                         if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                             self.evaluate()
@@ -582,7 +595,25 @@ class Trainer:
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
         return TrainOutput(self.global_step, tr_loss / self.global_step)
 
-    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
+    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
+        """
+        Log :obj:`logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (:obj:`Dict[str, float]`):
+                The values to log.
+            iterator (:obj:`tqdm`, `optional`):
+                A potential tqdm progress bar to write the logs on.
+        """
+        if hasattr(self, "_log"):
+            warnings.warn(
+                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
+                FutureWarning,
+            )
+            return self._log(logs, iterator=iterator)
+
         if self.epoch is not None:
             logs["epoch"] = self.epoch
         if self.global_step is None:
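log follows the same pattern: the private _log becomes a documented override point, with a deprecation shim for subclasses that still define _log. A minimal sketch of a subclass that mirrors metrics to a JSON-lines file, not part of the commit (JsonLogTrainer and the output path are made-up):

    import json
    from typing import Dict, Optional

    from tqdm.auto import tqdm

    from transformers import Trainer


    class JsonLogTrainer(Trainer):  # hypothetical subclass
        def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
            super().log(logs, iterator=iterator)  # keep the default TensorBoard/W&B handling
            if self.is_world_master():
                with open("trainer_logs.jsonl", "a") as f:  # hypothetical output path
                    f.write(json.dumps({"step": self.global_step, **logs}) + "\n")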
@@ -612,10 +643,13 @@ class Trainer:
         else:
             logger.info(output)
 
-    def _training_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
-    ) -> float:
-        model.train()
+    def _prepare_inputs(
+        self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module
+    ) -> Dict[str, Union[torch.Tensor, Any]]:
+        """
+        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
+        handling potential state.
+        """
         for k, v in inputs.items():
             if isinstance(v, torch.Tensor):
                 inputs[k] = v.to(self.args.device)
@@ -625,9 +659,44 @@ class Trainer:
         # Our model outputs do not work with DataParallel, so forcing return tuple.
         if isinstance(model, nn.DataParallel):
             inputs["return_tuple"] = True
 
+        return inputs
+
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
+    ) -> float:
+        """
+        Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (:obj:`nn.Module`):
+                The model to train.
+            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+            optimizer (:obj:`torch.optim.Optimizer`):
+                The optimizer to use to make a step.
+
+        Return:
+            `float`:
+                The training loss on this batch.
+        """
+        if hasattr(self, "_training_step"):
+            warnings.warn(
+                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
+                FutureWarning,
+            )
+            return self._training_step(model, inputs, optimizer)
+
+        model.train()
+        inputs = self._prepare_inputs(inputs, model)
+
         outputs = model(**inputs)
-        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+        # We don't use .loss here since the model may return tuples instead of ModelOutput.
+        loss = outputs[0]
 
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
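training_step is now the public per-batch hook, with input handling factored out into _prepare_inputs, so a subclass can adjust a batch and still reuse the stock forward/backward logic. A minimal sketch, not part of the commit (InputDropTrainer and drop_token_prob are made-up names, and zeroing random input ids is only a crude illustration):

    from typing import Any, Dict, Union

    import torch
    import torch.nn as nn

    from transformers import Trainer


    class InputDropTrainer(Trainer):  # hypothetical subclass
        def __init__(self, *args, drop_token_prob: float = 0.1, **kwargs):
            super().__init__(*args, **kwargs)
            self.drop_token_prob = drop_token_prob

        def training_step(
            self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
        ) -> float:
            # Randomly zero out some input ids, then reuse the default step
            # (forward pass, loss scaling and backward stay unchanged).
            if "input_ids" in inputs:
                mask = torch.rand_like(inputs["input_ids"], dtype=torch.float) < self.drop_token_prob
                inputs["input_ids"] = inputs["input_ids"].masked_fill(mask, 0)
            return super().training_step(model, inputs, optimizer)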
@@ -741,6 +810,8 @@ class Trainer:
         The calling script will be responsible for providing a method to compute metrics, as they are
         task-dependent (pass it to the init :obj:`compute_metrics` argument).
 
+        You can also subclass and override this method to inject custom behavior.
+
         Args:
             eval_dataset (:obj:`Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`.
@@ -749,9 +820,9 @@ class Trainer:
         """
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
 
-        output = self._prediction_loop(eval_dataloader, description="Evaluation")
+        output = self.prediction_loop(eval_dataloader, description="Evaluation")
 
-        self._log(output.metrics)
+        self.log(output.metrics)
 
         if self.args.tpu_metrics_debug or self.args.debug:
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
@@ -780,16 +851,22 @@ class Trainer:
         """
         test_dataloader = self.get_test_dataloader(test_dataset)
 
-        return self._prediction_loop(test_dataloader, description="Prediction")
+        return self.prediction_loop(test_dataloader, description="Prediction")
 
-    def _prediction_loop(
+    def prediction_loop(
         self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
     ) -> PredictionOutput:
         """
-        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
+        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
 
         Works both with or without labels.
         """
+        if hasattr(self, "_prediction_loop"):
+            warnings.warn(
+                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
+                FutureWarning,
+            )
+            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
+
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
 
@@ -815,40 +892,20 @@ class Trainer:
             dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
 
         if self.args.past_index >= 0:
-            past = None
+            self._past = None
 
         for inputs in tqdm(dataloader, desc=description):
-            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
-
-            for k, v in inputs.items():
-                if isinstance(v, torch.Tensor):
-                    inputs[k] = v.to(self.args.device)
-            if self.args.past_index >= 0:
-                inputs["mems"] = past
-            # Our model outputs do not work with DataParallel, so forcing return tuple.
-            if isinstance(model, nn.DataParallel):
-                inputs["return_tuple"] = True
-
-            with torch.no_grad():
-                outputs = model(**inputs)
-                if has_labels:
-                    step_eval_loss, logits = outputs[:2]
-                    eval_losses += [step_eval_loss.mean().item()]
-                else:
-                    logits = outputs[0]
-                if self.args.past_index >= 0:
-                    past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
-
-            if not prediction_loss_only:
-                if preds is None:
-                    preds = logits.detach()
-                else:
-                    preds = torch.cat((preds, logits.detach()), dim=0)
-                if inputs.get("labels") is not None:
-                    if label_ids is None:
-                        label_ids = inputs["labels"].detach()
-                    else:
-                        label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
+            if loss is not None:
+                eval_losses.append(loss)
+            if logits is not None:
+                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
+            if labels is not None:
+                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
+
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
 
         if self.args.local_rank != -1:
             # In distributed mode, concatenate all results from all nodes:
@@ -894,3 +951,49 @@ class Trainer:
             # truncate the dummy elements added by SequentialDistributedSampler
             output = concat[:num_total_examples]
         return output
+
+    def prediction_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on :obj:`model` using obj:`inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (:obj:`nn.Module`):
+                The model to evaluate.
+            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (:obj:`bool`):
+                Whether or not to return the loss only.
+
+        Return:
+            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+                A tuple with the loss, logits and labels (each being optional).
+        """
+        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
+
+        inputs = self._prepare_inputs(inputs, model)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+            if has_labels:
+                loss, logits = outputs[:2]
+                loss = loss.mean().item()
+            else:
+                loss = None
+                logits = outputs[0]
+            if self.args.past_index >= 0:
+                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        labels = inputs.get("labels")
+        if labels is not None:
+            labels = labels.detach()
+        return (loss, logits.detach(), labels)
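prediction_step is the evaluation-side counterpart of training_step: prediction_loop now only accumulates whatever (loss, logits, labels) the step returns. A minimal sketch of a subclass that post-processes logits before they are gathered, not part of the commit (SoftmaxTrainer is a made-up name):

    from typing import Any, Dict, Optional, Tuple, Union

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from transformers import Trainer


    class SoftmaxTrainer(Trainer):  # hypothetical subclass
        def prediction_step(
            self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
        ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
            if logits is not None:
                # Return class probabilities instead of raw logits.
                logits = F.softmax(logits, dim=-1)
            return (loss, logits, labels)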