mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix the bug where Trainer
could not correctly call torch_jit_model_eval
(#35722)
Fix the bug where accelerator.autocast did not pass parameters correctly when calling torch_jit_model_eval (#35706)
This commit is contained in:
parent
2cbcc5877d
commit
8b78d9d6e7
@ -230,6 +230,7 @@ if is_accelerate_available():
|
||||
from accelerate import __version__ as accelerate_version
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import (
|
||||
AutocastKwargs,
|
||||
DistributedDataParallelKwargs,
|
||||
DistributedType,
|
||||
load_fsdp_model,
|
||||
@ -1832,7 +1833,8 @@ class Trainer:
|
||||
# remove mixed precision hooks from the model
|
||||
if original_forward:
|
||||
jit_model.forward = original_forward
|
||||
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
|
||||
autocast_handler = AutocastKwargs(cache_enabled=False)
|
||||
with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
|
||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
|
||||
if isinstance(example_batch, dict):
|
||||
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
|
||||
|
Loading…
Reference in New Issue
Block a user