Pin minimum PyTorch version for BLOOM ONNX export (#19046)

This commit is contained in:
lewtun 2022-09-15 15:22:31 +02:00 committed by GitHub
parent 0a42b61ede
commit 9b80a0bc18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -16,6 +16,8 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from packaging import version
from transformers import is_torch_available
@ -154,6 +156,9 @@ class BloomConfig(PretrainedConfig):
class BloomOnnxConfig(OnnxConfigWithPast):
torch_onnx_minimum_version = version.parse("1.12")
def __init__(
self,
config: PretrainedConfig,