Fix the bug that Trainer cannot correctly call torch_jit_model_eval (#35722)

Fix the bug that the accelerator.autocast does not pass parameters correctly when calling torch_jit_model_eval (#35706)
This commit is contained in:
 人民艺术家 2025-01-16 22:53:37 +08:00 committed by GitHub
parent 2cbcc5877d
commit 8b78d9d6e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -230,6 +230,7 @@ if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate.state import AcceleratorState
     from accelerate.utils import (
+        AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
         load_fsdp_model,
@@ -1832,7 +1833,8 @@ class Trainer:
             # remove mixed precision hooks from the model
             if original_forward:
                 jit_model.forward = original_forward
-            with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
+            autocast_handler = AutocastKwargs(cache_enabled=False)
+            with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
             if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
                 if isinstance(example_batch, dict):
                     jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)