Fix the deprecation warning of _torch_pytree._register_pytree_node (#27803)

This commit is contained in:
cyyever 2023-12-17 18:13:42 +08:00 committed by GitHub
parent f85a1e82c1
commit e6dcf8abd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -306,7 +306,7 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
_torch_pytree._register_pytree_node(
torch_pytree_register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
@ -438,7 +438,11 @@ if is_torch_available():
output_type, keys = context
return output_type(**dict(zip(keys, values)))
_torch_pytree._register_pytree_node(
if hasattr(_torch_pytree, "register_pytree_node"):
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
else:
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
torch_pytree_register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,