Fix copies

This commit is contained in:
Lysandre 2021-10-26 15:48:28 -04:00
parent 3f23634a17
commit 27c888db6c

View File

@ -494,6 +494,21 @@ class XSoftmax(torch.autograd.Function):
inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
return inputGrad, None, None
@staticmethod
def symbolic(g, self, mask, dim):
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_opset9 import masked_fill, softmax
mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
class DropoutContext(object):