mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Register ModelOutput as supported torch pytree nodes (#26618)
* Register ModelOutput as supported torch pytree nodes
* Test ModelOutput as supported torch pytree nodes
* Update type hints for pytree unflatten functions
This commit is contained in:
parent
ede051f1b8
commit
cc7803c0a6
@ -22,7 +22,7 @@ from collections.abc import MutableMapping
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, ContextManager, List, Tuple
|
||||
from typing import Any, ContextManager, Iterable, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -306,12 +306,10 @@ class ModelOutput(OrderedDict):
|
||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||
"""
|
||||
if is_torch_available():
|
||||
import torch.utils._pytree
|
||||
|
||||
torch.utils._pytree._register_pytree_node(
|
||||
_torch_pytree._register_pytree_node(
|
||||
cls,
|
||||
torch.utils._pytree._dict_flatten,
|
||||
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
|
||||
_model_output_flatten,
|
||||
_model_output_unflatten,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -430,6 +428,23 @@ class ModelOutput(OrderedDict):
|
||||
return tuple(self[k] for k in self.keys())
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch.utils._pytree as _torch_pytree
|
||||
|
||||
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
|
||||
return list(output.values()), (type(output), list(output.keys()))
|
||||
|
||||
def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
|
||||
output_type, keys = context
|
||||
return output_type(**dict(zip(keys, values)))
|
||||
|
||||
_torch_pytree._register_pytree_node(
|
||||
ModelOutput,
|
||||
_model_output_flatten,
|
||||
_model_output_unflatten,
|
||||
)
|
||||
|
||||
|
||||
class ExplicitEnum(str, Enum):
|
||||
"""
|
||||
Enum with more explicit error message for missing values.
|
||||
|
@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase):
|
||||
def test_torch_pytree(self):
|
||||
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
|
||||
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
|
||||
import torch
|
||||
import torch.utils._pytree
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
x = ModelOutput({"a": 1.0, "c": 2.0})
|
||||
self.assertFalse(pytree._is_leaf(x))
|
||||
|
||||
x = ModelOutputTest(a=1.0, c=2.0)
|
||||
self.assertFalse(torch.utils._pytree._is_leaf(x))
|
||||
self.assertFalse(pytree._is_leaf(x))
|
||||
|
||||
expected_flat_outs = [1.0, 2.0]
|
||||
expected_tree_spec = torch.utils._pytree.TreeSpec(
|
||||
ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()]
|
||||
expected_tree_spec = pytree.TreeSpec(
|
||||
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
|
||||
)
|
||||
|
||||
actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
|
||||
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||
self.assertEqual(expected_tree_spec, actual_tree_spec)
|
||||
|
||||
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user