use accelerate autocast in jit eval path, since mix precision logic is… (#24460)

use accelerate autocast in jit eval path, since mix precision logic is in accelerator currently

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2023-06-27 20:33:21 +08:00 committed by GitHub
parent 0863436b6c
commit 6fe8d198e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
"""
import contextlib
import copy
import functools
import glob
import inspect
@ -143,7 +144,6 @@ from .utils import (
logging,
strtobool,
)
from .utils.generic import ContextManagers
DEFAULT_CALLBACKS = [DefaultFlowCallback]
@ -1265,9 +1265,14 @@ class Trainer:
example_batch = next(iter(dataloader))
example_batch = self._prepare_inputs(example_batch)
try:
jit_model = model.eval()
with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
jit_model = copy.copy(model)
jit_model.eval()
original_forward = jit_model.__dict__.pop("_original_forward", None)
# remove mixed precision hooks from the model
if original_forward:
jit_model.forward = original_forward
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
if isinstance(example_batch, dict):
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
else: