mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix the deprecation warning of _torch_pytree._register_pytree_node (#27803)
This commit is contained in:
parent
f85a1e82c1
commit
e6dcf8abd6
@ -306,7 +306,7 @@ class ModelOutput(OrderedDict):
|
||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||
"""
|
||||
if is_torch_available():
|
||||
_torch_pytree._register_pytree_node(
|
||||
torch_pytree_register_pytree_node(
|
||||
cls,
|
||||
_model_output_flatten,
|
||||
_model_output_unflatten,
|
||||
@ -438,7 +438,11 @@ if is_torch_available():
|
||||
output_type, keys = context
|
||||
return output_type(**dict(zip(keys, values)))
|
||||
|
||||
_torch_pytree._register_pytree_node(
|
||||
if hasattr(_torch_pytree, "register_pytree_node"):
|
||||
torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node
|
||||
else:
|
||||
torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node
|
||||
torch_pytree_register_pytree_node(
|
||||
ModelOutput,
|
||||
_model_output_flatten,
|
||||
_model_output_unflatten,
|
||||
|
Loading…
Reference in New Issue
Block a user