mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix the bug that Trainer
cannot correctly call torch_jit_model_eval
(#35722)
Fix the bug that the accelerator.autocast does not pass parameters correctly when calling torch_jit_model_eval (#35706)
This commit is contained in:
parent
2cbcc5877d
commit
8b78d9d6e7
@ -230,6 +230,7 @@ if is_accelerate_available():
|
|||||||
from accelerate import __version__ as accelerate_version
|
from accelerate import __version__ as accelerate_version
|
||||||
from accelerate.state import AcceleratorState
|
from accelerate.state import AcceleratorState
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
|
AutocastKwargs,
|
||||||
DistributedDataParallelKwargs,
|
DistributedDataParallelKwargs,
|
||||||
DistributedType,
|
DistributedType,
|
||||||
load_fsdp_model,
|
load_fsdp_model,
|
||||||
@ -1832,7 +1833,8 @@ class Trainer:
|
|||||||
# remove mixed precision hooks from the model
|
# remove mixed precision hooks from the model
|
||||||
if original_forward:
|
if original_forward:
|
||||||
jit_model.forward = original_forward
|
jit_model.forward = original_forward
|
||||||
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
|
autocast_handler = AutocastKwargs(cache_enabled=False)
|
||||||
|
with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
|
||||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
|
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
|
||||||
if isinstance(example_batch, dict):
|
if isinstance(example_batch, dict):
|
||||||
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
|
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user