Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-31 02:02:21 +06:00
Fix copies
This commit is contained in:
parent 3f23634a17
commit 27c888db6c
@@ -494,6 +494,21 @@ class XSoftmax(torch.autograd.Function):
         inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
         return inputGrad, None, None
 
+    @staticmethod
+    def symbolic(g, self, mask, dim):
+        import torch.onnx.symbolic_helper as sym_help
+        from torch.onnx.symbolic_opset9 import masked_fill, softmax
+
+        mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
+        r_mask = g.op(
+            "Cast",
+            g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
+            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+        )
+        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+        output = softmax(g, output, dim)
+        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+
 
 # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
 class DropoutContext(object):
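Note on the added `symbolic` method: it lets `torch.onnx.export` translate the custom `XSoftmax` autograd function into plain ONNX operators by inverting the attention mask, filling masked positions with `-inf` before the softmax, and setting them back to `0` afterwards. The following is a minimal eager-mode sketch of that same logic for illustration only; it is not part of the commit, and the helper name `masked_softmax` is hypothetical.

    import torch

    # Illustrative sketch: eager-mode equivalent of the ONNX graph built by
    # XSoftmax.symbolic above. `masked_softmax` is a made-up helper name.
    def masked_softmax(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
        r_mask = ~mask.bool()                        # 1 - mask: positions to ignore
        x = x.masked_fill(r_mask, float("-inf"))     # exclude them from the softmax
        out = torch.softmax(x, dim=dim)
        return out.masked_fill(r_mask, 0.0)          # force ignored positions to exactly 0

    scores = torch.randn(1, 4, 4)
    mask = torch.tensor([[[1, 1, 1, 0]] * 4])        # last position is padding
    print(masked_softmax(scores, mask, dim=-1))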