Only put tensors on a device (#5223)

* Only put tensors on a device

* Type hint and unpack list comprehension
This commit is contained in:
Sylvain Gugger 2020-06-23 17:30:17 -04:00 committed by GitHub
parent 173528e368
commit 9022ef021a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -7,7 +7,7 @@ import shutil
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -570,11 +570,12 @@ class Trainer:
         logger.info(output)
     def _training_step(
-        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
     ) -> float:
         model.train()
         for k, v in inputs.items():
-            inputs[k] = v.to(self.args.device)
+            if isinstance(v, torch.Tensor):
+                inputs[k] = v.to(self.args.device)
         outputs = model(**inputs)
         loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
@@ -758,7 +759,8 @@ class Trainer:
         has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
         for k, v in inputs.items():
-            inputs[k] = v.to(self.args.device)
+            if isinstance(v, torch.Tensor):
+                inputs[k] = v.to(self.args.device)
         with torch.no_grad():
             outputs = model(**inputs)