Handle nested dict/lists of tensors as inputs in the Trainer (#13338)

This commit is contained in:
Sylvain Gugger 2021-08-31 06:34:31 -04:00 committed by GitHub
parent 3efcfeab67
commit 4d10474fa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1727,22 +1727,30 @@ class Trainer:
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
    """
    Prepares one :obj:`data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.

    Containers (dicts, lists, tuples and namedtuples) are rebuilt with the same type and their elements are
    prepared recursively. Tensors are moved to :obj:`self.args.device`; when DeepSpeed is active, floating point
    tensors are additionally cast to the dtype returned by :obj:`self.args.hf_deepspeed_config.dtype()`. Any
    other object is returned unchanged.

    Args:
        data: A tensor, a (possibly nested) dict/list/tuple of tensors, or any other object.

    Returns:
        An object with the same structure as :obj:`data` in which every tensor has been moved (and, if
        applicable, cast).
    """
    if isinstance(data, dict):
        # Hand the rebuilt mapping over positionally instead of unpacking it with ``**``: keyword unpacking
        # raises a TypeError as soon as one key is not a string.
        return type(data)({k: self._prepare_input(v) for k, v in data.items()})
    if isinstance(data, (tuple, list)):
        prepared = [self._prepare_input(v) for v in data]
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # namedtuple constructors expect one positional argument per field, not a single iterable.
            return type(data)(*prepared)
        return type(data)(prepared)
    if isinstance(data, torch.Tensor):
        kwargs = {"device": self.args.device}
        if self.deepspeed and torch.is_floating_point(data):
            # NLP models inputs are int64 and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model. Only floating point
            # tensors are cast: integer or boolean tensors (ids, labels, masks) keep their dtype.
            kwargs["dtype"] = self.args.hf_deepspeed_config.dtype()
        return data.to(**kwargs)
    return data
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
kwargs = dict(device=self.args.device)
if self.deepspeed and inputs[k].dtype != torch.int64:
# NLP models inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
inputs[k] = v.to(**kwargs)
inputs = self._prepare_input(inputs)
if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past