Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Register ModelOutput subclasses as supported torch.utils._pytree nodes (#25358)
* Register ModelOutput subclasses as supported torch.utils._pytree nodes

Fixes #25357 where DDP with static_graph=True does not sync gradients when calling backward() over tensors contained in ModelOutput subclasses

* Add test for torch pytree ModelOutput serialization and deserialization
parent a23ac36f8c
commit d4bd33cc9f
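
Background (not part of the commit): the sketch below illustrates why pytree node registration addresses the DDP issue described above. MyOutput is a hypothetical dict-backed container standing in for a ModelOutput subclass, and the snippet assumes a torch version that still exposes the private torch.utils._pytree helpers used by this patch (newer releases also provide a public register_pytree_node). An unregistered container is a single pytree leaf, so anything that walks module outputs with tree_flatten, as DistributedDataParallel does, never sees the tensors inside it.

from collections import OrderedDict

import torch
import torch.utils._pytree as pytree


class MyOutput(OrderedDict):
    """Stand-in for a dict-backed ModelOutput-like container."""


out = MyOutput(logits=torch.ones(2), loss=torch.zeros(1))

# Unregistered: the whole container is one leaf, so its tensors stay hidden
# from anything that locates output tensors via tree_flatten.
leaves, _ = pytree.tree_flatten(out)
print(len(leaves))  # 1

# Register the class as a pytree node, mirroring what the patch does per subclass.
pytree._register_pytree_node(
    MyOutput,
    pytree._dict_flatten,
    lambda values, context: MyOutput(pytree._dict_unflatten(values, context)),
)

leaves, spec = pytree.tree_flatten(out)
print(leaves)  # [tensor([1., 1.]), tensor([0.])] -- the contained tensors are now visible
print(pytree.tree_unflatten(leaves, spec))  # round-trips back to a MyOutput
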
@@ -248,6 +248,21 @@ class ModelOutput(OrderedDict):
 
     </Tip>
     """
 
+    def __init_subclass__(cls) -> None:
+        """Register subclasses as pytree nodes.
+
+        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
+        `static_graph=True` with modules that output `ModelOutput` subclasses.
+        """
+        if is_torch_available():
+            import torch.utils._pytree
+
+            torch.utils._pytree._register_pytree_node(
+                cls,
+                torch.utils._pytree._dict_flatten,
+                lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
+            )
+
     def __post_init__(self):
         class_fields = fields(self)
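
Because the registration lives in __init_subclass__, every ModelOutput subclass (including user-defined ones) becomes a pytree node at class-definition time, with no per-class registration call. A minimal sketch of that effect, assuming a transformers install that contains this change and a torch version with the _pytree helpers above; MySubOutput is a hypothetical example class.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.utils._pytree as pytree
from transformers.utils import ModelOutput


@dataclass
class MySubOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


# __init_subclass__ already ran when the class was defined, so no explicit
# registration call is needed before flattening instances.
out = MySubOutput(loss=torch.tensor(0.5), logits=torch.ones(2, 3))

leaves, spec = pytree.tree_flatten(out)
print([t.shape for t in leaves])  # [torch.Size([]), torch.Size([2, 3])]

restored = pytree.tree_unflatten(leaves, spec)
print(type(restored).__name__, list(restored.keys()))  # MySubOutput ['loss', 'logits']
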
@@ -17,6 +17,7 @@ import unittest
 
 from dataclasses import dataclass
 from typing import Optional
 
+from transformers.testing_utils import require_torch
 from transformers.utils import ModelOutput
 
@@ -120,3 +121,25 @@ class ModelOutputTester(unittest.TestCase):
         x = ModelOutputTest(a=(30, 30))
         self.assertEqual(list(x.keys()), ["a"])
         self.assertEqual(x.a, (30, 30))
+
+    @require_torch
+    def test_torch_pytree(self):
+        # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
+        # this is important for DistributedDataParallel gradient synchronization with static_graph=True
+        import torch
+        import torch.utils._pytree
+
+        x = ModelOutputTest(a=1.0, c=2.0)
+        self.assertFalse(torch.utils._pytree._is_leaf(x))
+
+        expected_flat_outs = [1.0, 2.0]
+        expected_tree_spec = torch.utils._pytree.TreeSpec(
+            ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()]
+        )
+
+        actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
+        self.assertEqual(expected_flat_outs, actual_flat_outs)
+        self.assertEqual(expected_tree_spec, actual_tree_spec)
+
+        unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
+        self.assertEqual(x, unflattened_x)
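
For completeness, a hedged single-process sketch of the scenario the fix targets: DDP with static_graph=True calling backward() on a tensor pulled out of a ModelOutput subclass. ToyOutput and ToyModel are illustrative names, the gloo group with world_size=1 only makes DDP constructible for demonstration, and a real run would launch multiple processes (for example via torchrun). It assumes a transformers build that includes this patch.

import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.utils import ModelOutput


@dataclass
class ToyOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        # returning a ModelOutput subclass is what previously hid the tensor from DDP
        return ToyOutput(loss=self.linear(x).sum())


# single-process CPU group, only so that DDP can be constructed for the sketch
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(ToyModel(), static_graph=True)
out = model(torch.randn(2, 4))

# with ModelOutput registered as a pytree node, DDP's output traversal sees
# out.loss, so backward() triggers gradient synchronization as expected
out.loss.backward()

dist.destroy_process_group()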