Apply torchfix to replace deprecated functions: _pytree._register_pytree_node and torch.cpu.amp.autocast (#37372)

fix: apply torchfix
This commit is contained in:
Brayden Zhong 2025-04-09 11:11:18 -04:00 committed by GitHub
parent ad340908e4
commit 71b35387fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 3 deletions

View File

@ -3689,7 +3689,7 @@ class Trainer:
arguments, depending on the situation.
"""
if self.use_cpu_amp:
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
ctx_manager = torch.amp.autocast("cpu", cache_enabled=cache_enabled, dtype=self.amp_dtype)
else:
ctx_manager = contextlib.nullcontext()

View File

@ -343,14 +343,18 @@ class ModelOutput(OrderedDict):
"""
if is_torch_available():
if version.parse(get_torch_version()) >= version.parse("2.2"):
_torch_pytree.register_pytree_node(
from torch.utils._pytree import register_pytree_node
register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
)
else:
_torch_pytree._register_pytree_node(
from torch.utils._pytree import register_pytree_node
register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),