mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Handle nested dict/lists of tensors as inputs in the Trainer (#13338)
This commit is contained in:
parent
3efcfeab67
commit
4d10474fa5
@ -1727,22 +1727,30 @@ class Trainer:
|
||||
self.state.log_history.append(output)
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
|
||||
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
    """
    Prepares one :obj:`data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.

    Containers (``dict``, ``list``, ``tuple`` and their subclasses, including namedtuples) are rebuilt with the same
    type, with every element prepared recursively. Tensors are moved to :obj:`self.args.device`. Any other object is
    returned unchanged.

    Args:
        data (:obj:`torch.Tensor` or :obj:`Any`):
            A tensor, a (possibly nested) container of tensors, or an arbitrary object.

    Returns:
        :obj:`torch.Tensor` or :obj:`Any`: The prepared counterpart of :obj:`data`, with the same structure.
    """
    if isinstance(data, dict):
        prepared = {k: self._prepare_input(v) for k, v in data.items()}
        if all(isinstance(k, str) for k in prepared):
            # Keyword expansion keeps the historical construction path for string-keyed mappings.
            return type(data)(**prepared)
        # Keyword names must be strings, so mappings with other key types are rebuilt from a mapping instead.
        return type(data)(prepared)
    elif isinstance(data, (tuple, list)):
        converted = (self._prepare_input(v) for v in data)
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # Namedtuple constructors take one positional argument per field rather than a single iterable.
            return type(data)(*converted)
        return type(data)(converted)
    elif isinstance(data, torch.Tensor):
        kwargs = {"device": self.args.device}
        if self.deepspeed and data.dtype != torch.int64:
            # NLP models inputs are int64 and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model
            kwargs["dtype"] = self.args.hf_deepspeed_config.dtype()
        return data.to(**kwargs)
    return data
|
||||
|
||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||
"""
|
||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||
handling potential state.
|
||||
"""
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
kwargs = dict(device=self.args.device)
|
||||
if self.deepspeed and inputs[k].dtype != torch.int64:
|
||||
# NLP models inputs are int64 and those get adjusted to the right dtype of the
|
||||
# embedding. Other models such as wav2vec2's inputs are already float and thus
|
||||
# may need special handling to match the dtypes of the model
|
||||
kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
|
||||
|
||||
inputs[k] = v.to(**kwargs)
|
||||
|
||||
inputs = self._prepare_input(inputs)
|
||||
if self.args.past_index >= 0 and self._past is not None:
|
||||
inputs["mems"] = self._past
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user