Mirror of https://github.com/huggingface/transformers.git
Only put tensors on a device (#5223)
* Only put tensors on a device
* Type hint and unpack list comprehension
parent 173528e368
commit 9022ef021a
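
In short: before this commit, Trainer moved every value of the inputs dict to the target device with v.to(self.args.device), which raises AttributeError as soon as a batch carries a non-tensor entry (a list, an int, ...), since only torch.Tensor has a .to() method. A minimal sketch of the guarded loop, standalone rather than inside Trainer (the key names and the cpu device are illustrative, not from the commit):

import torch

device = torch.device("cpu")  # stand-in for self.args.device

# A batch whose values are not all tensors, e.g. a cache kept as a plain list.
inputs = {
    "input_ids": torch.tensor([[101, 2023, 102]]),
    "mems": [torch.zeros(1, 3)],  # a list has no .to() method
}

for k, v in inputs.items():
    # The pre-commit behavior, inputs[k] = v.to(device), would crash on "mems".
    if isinstance(v, torch.Tensor):
        inputs[k] = v.to(device)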
@@ -7,7 +7,7 @@ import shutil
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -570,11 +570,12 @@ class Trainer:
             logger.info(output)

     def _training_step(
-        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
     ) -> float:
         model.train()
         for k, v in inputs.items():
-            inputs[k] = v.to(self.args.device)
+            if isinstance(v, torch.Tensor):
+                inputs[k] = v.to(self.args.device)

         outputs = model(**inputs)
         loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
@@ -758,7 +759,8 @@ class Trainer:
             has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

             for k, v in inputs.items():
-                inputs[k] = v.to(self.args.device)
+                if isinstance(v, torch.Tensor):
+                    inputs[k] = v.to(self.args.device)

             with torch.no_grad():
                 outputs = model(**inputs)
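
The signature change mirrors the runtime guard: once non-tensor values are tolerated, Dict[str, torch.Tensor] no longer describes the batch, so _training_step is annotated with Dict[str, Union[torch.Tensor, Any]]. (To a type checker this union collapses to Any; keeping torch.Tensor in the annotation presumably documents the common case.)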