support ONNX export of XDropout in deberta{,_v2} and sew_d (#17502)

* support ONNX export of XDropout in deberta{,_v2}
* black
* copy to sew_d
* add test
* isort
* use pytest.mark.filterwarnings
* review comments
commit 9d7b70bcd7 (parent 92915ebec2)
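Background on the mechanism this change relies on: when torch.onnx.export traces a model, a torch.autograd.Function is exported through the Function's optional symbolic staticmethod, which receives the in-construction ONNX graph plus the forward inputs and must emit equivalent ONNX ops. A minimal sketch of that pattern follows; ToyDropout and its simplified masking are illustrative stand-ins, not code from this commit:

import torch


class ToyDropout(torch.autograd.Function):
    """Illustrative dropout Function with an ONNX symbolic, mirroring the XDropout pattern."""

    @staticmethod
    def forward(ctx, input, p):
        # Sample a keep-mask and rescale, as in inverted dropout.
        mask = (torch.rand_like(input) > p).to(input.dtype)
        ctx.save_for_backward(mask)
        ctx.scale = 1.0 / (1.0 - p)
        return input * mask * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Gradient flows only through kept elements; no gradient for p.
        return grad_output * mask * ctx.scale, None

    @staticmethod
    def symbolic(g, input, p):
        # Called by torch.onnx.export instead of tracing forward();
        # emit a plain ONNX Dropout node into the graph g.
        return g.op("Dropout", input)

Wrapping ToyDropout.apply in a small nn.Module and passing that module to torch.onnx.export is enough for the exporter to route through symbolic. The XDropout change below has the same shape, but delegates to torch.onnx.symbolic_opset12.dropout so the emitted node carries the real dropout probability and training flag.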
src/transformers/models/deberta/modeling_deberta.py:

@@ -185,6 +185,21 @@ class XDropout(torch.autograd.Function):
         else:
             return grad_output, None
 
+    @staticmethod
+    def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
+        dropout_p = local_ctx
+        if isinstance(local_ctx, DropoutContext):
+            dropout_p = local_ctx.dropout
+        # StableDropout only calls this function when training.
+        train = True
+        # TODO: We should check if the opset_version being used to export
+        # is >= 12 here, but there's no good way to do that. As-is, if the
+        # opset_version < 12, export will fail with a CheckerError.
+        # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
+        # if opset_version < 12:
+        #     return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
+        return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
 
 
 class StableDropout(nn.Module):
     """
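To see the effect of the new symbolic, one can export StableDropout in training mode and inspect the resulting graph. A quick sanity check along these lines mirrors the new unit test below, with io.BytesIO in place of os.devnull so the exported graph can be examined (assumes torch, onnx, and transformers are installed):

import io

import onnx
import torch
from transformers.models.deberta import modeling_deberta

sd = modeling_deberta.StableDropout(0.1)  # a fresh nn.Module is in training mode
buf = io.BytesIO()
torch.onnx.export(
    sd,
    (torch.randn(2, 2),),
    buf,
    opset_version=12,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.PRESERVE,  # keep dropout active in the export
)
model = onnx.load_from_string(buf.getvalue())
print([node.op_type for node in model.graph.node])  # a "Dropout" op should appear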
src/transformers/models/deberta_v2/modeling_deberta_v2.py:

@@ -191,6 +191,21 @@ class XDropout(torch.autograd.Function):
         else:
             return grad_output, None
 
+    @staticmethod
+    def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
+        dropout_p = local_ctx
+        if isinstance(local_ctx, DropoutContext):
+            dropout_p = local_ctx.dropout
+        # StableDropout only calls this function when training.
+        train = True
+        # TODO: We should check if the opset_version being used to export
+        # is >= 12 here, but there's no good way to do that. As-is, if the
+        # opset_version < 12, export will fail with a CheckerError.
+        # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
+        # if opset_version < 12:
+        #     return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
+        return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
 
 
 # Copied from transformers.models.deberta.modeling_deberta.StableDropout
 class StableDropout(nn.Module):
src/transformers/models/sew_d/modeling_sew_d.py:

@@ -595,6 +595,21 @@ class XDropout(torch.autograd.Function):
         else:
             return grad_output, None
 
+    @staticmethod
+    def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
+        dropout_p = local_ctx
+        if isinstance(local_ctx, DropoutContext):
+            dropout_p = local_ctx.dropout
+        # StableDropout only calls this function when training.
+        train = True
+        # TODO: We should check if the opset_version being used to export
+        # is >= 12 here, but there's no good way to do that. As-is, if the
+        # opset_version < 12, export will fail with a CheckerError.
+        # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
+        # if opset_version < 12:
+        #     return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
+        return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
 
 
 # Copied from transformers.models.deberta.modeling_deberta.StableDropout
 class StableDropout(nn.Module):
tests/onnx/test_onnx_v2.py:

@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from unittest import TestCase

@@ -26,6 +27,11 @@ from transformers.testing_utils import require_onnx, require_rjieba, require_tf,
 if is_torch_available() or is_tf_available():
     from transformers.onnx.features import FeaturesManager
 
+if is_torch_available():
+    import torch
+
+    from transformers.models.deberta import modeling_deberta
+
 
 @require_onnx
 class OnnxUtilsTestCaseV2(TestCase):

@@ -356,3 +362,40 @@ class OnnxExportTestCaseV2(TestCase):
         self, test_name, name, model_name, feature, onnx_config_class_constructor
     ):
         self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+
+
+class StableDropoutTestCase(TestCase):
+    """Tests export of StableDropout module."""
+
+    @require_torch
+    @pytest.mark.filterwarnings("ignore:.*Dropout.*:UserWarning:torch.onnx.*")  # torch.onnx is spammy.
+    def test_training(self):
+        """Tests export of StableDropout in training mode."""
+        devnull = open(os.devnull, "wb")
+        # drop_prob must be > 0 for the test to be meaningful
+        sd = modeling_deberta.StableDropout(0.1)
+        # Avoid warnings in training mode
+        do_constant_folding = False
+        # Dropout is a no-op in inference mode
+        training = torch.onnx.TrainingMode.PRESERVE
+        input = (torch.randn(2, 2),)
+
+        torch.onnx.export(
+            sd,
+            input,
+            devnull,
+            opset_version=12,  # Minimum supported
+            do_constant_folding=do_constant_folding,
+            training=training,
+        )
+
+        # Expected to fail with opset_version < 12
+        with self.assertRaises(Exception):
+            torch.onnx.export(
+                sd,
+                input,
+                devnull,
+                opset_version=11,
+                do_constant_folding=do_constant_folding,
+                training=training,
+            )
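One detail worth noting in the test above: the string given to pytest.mark.filterwarnings uses the same "action:message:category:module" syntax as Python's -W command-line option. The equivalent filter written directly against the warnings module would be:

import warnings

# Ignore UserWarnings whose message matches ".*Dropout.*" when they are
# raised from modules matching "torch.onnx.*".
warnings.filterwarnings(
    "ignore",
    message=".*Dropout.*",
    category=UserWarning,
    module="torch.onnx.*",
)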