mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Apply torchfix to replace deprecated functions: _pytree._register_pytree_node
and torch.cpu.amp.autocast
(#37372)
fix: apply torchfix
This commit is contained in:
parent
ad340908e4
commit
71b35387fd
@ -3689,7 +3689,7 @@ class Trainer:
|
||||
arguments, depending on the situation.
|
||||
"""
|
||||
if self.use_cpu_amp:
|
||||
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
|
||||
ctx_manager = torch.amp.autocast("cpu", cache_enabled=cache_enabled, dtype=self.amp_dtype)
|
||||
else:
|
||||
ctx_manager = contextlib.nullcontext()
|
||||
|
||||
|
@ -343,14 +343,18 @@ class ModelOutput(OrderedDict):
|
||||
"""
|
||||
if is_torch_available():
|
||||
if version.parse(get_torch_version()) >= version.parse("2.2"):
|
||||
_torch_pytree.register_pytree_node(
|
||||
from torch.utils._pytree import register_pytree_node
|
||||
|
||||
register_pytree_node(
|
||||
cls,
|
||||
_model_output_flatten,
|
||||
partial(_model_output_unflatten, output_type=cls),
|
||||
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
|
||||
)
|
||||
else:
|
||||
_torch_pytree._register_pytree_node(
|
||||
from torch.utils._pytree import register_pytree_node
|
||||
|
||||
register_pytree_node(
|
||||
cls,
|
||||
_model_output_flatten,
|
||||
partial(_model_output_unflatten, output_type=cls),
|
||||
|
Loading…
Reference in New Issue
Block a user