raise atol for MT5OnnxConfig (#18560)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-08-10 22:41:58 +02:00 committed by GitHub
parent f62cb8313c
commit 9a9a525be8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -147,9 +147,9 @@ class MT5Config(PretrainedConfig):
return self.num_layers
# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig.inputs
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = {
"input_ids": {0: "batch", 1: "encoder_sequence"},
@ -169,5 +169,10 @@ class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
return common_inputs
@property
# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig.default_onnx_opset
def default_onnx_opset(self) -> int:
return 13
@property
def atol_for_validation(self) -> float:
return 5e-4