mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
use accelerate autocast in jit eval path, since mix precision logic is… (#24460)
use accelerate autocast in jit eval path, since mix precision logic is in accelerator currently Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
0863436b6c
commit
6fe8d198e3
@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import glob
|
||||
import inspect
|
||||
@ -143,7 +144,6 @@ from .utils import (
|
||||
logging,
|
||||
strtobool,
|
||||
)
|
||||
from .utils.generic import ContextManagers
|
||||
|
||||
|
||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||
@ -1265,9 +1265,14 @@ class Trainer:
|
||||
example_batch = next(iter(dataloader))
|
||||
example_batch = self._prepare_inputs(example_batch)
|
||||
try:
|
||||
jit_model = model.eval()
|
||||
with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
|
||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
|
||||
jit_model = copy.copy(model)
|
||||
jit_model.eval()
|
||||
original_forward = jit_model.__dict__.pop("_original_forward", None)
|
||||
# remove mixed precision hooks from the model
|
||||
if original_forward:
|
||||
jit_model.forward = original_forward
|
||||
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
|
||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
|
||||
if isinstance(example_batch, dict):
|
||||
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user