Fix the bug where Trainer cannot correctly call torch_jit_model_eval (#35722)

Fix the bug where accelerator.autocast does not pass its parameters correctly when called from torch_jit_model_eval (#35706)
This commit is contained in:
 人民艺术家 2025-01-16 22:53:37 +08:00 committed by GitHub
parent 2cbcc5877d
commit 8b78d9d6e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -230,6 +230,7 @@ if is_accelerate_available():
from accelerate import __version__ as accelerate_version
from accelerate.state import AcceleratorState
from accelerate.utils import (
AutocastKwargs,
DistributedDataParallelKwargs,
DistributedType,
load_fsdp_model,
@ -1832,7 +1833,8 @@ class Trainer:
# remove mixed precision hooks from the model
if original_forward:
jit_model.forward = original_forward
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
autocast_handler = AutocastKwargs(cache_enabled=False)
with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
if isinstance(example_batch, dict):
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)