Fix forward reference imports in DeBERTa configs (#17800)

This commit is contained in:
Sylvain Gugger 2022-06-21 11:21:06 -04:00 committed by GitHub
parent 27e907386a
commit 7bc88c0511
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 6 deletions

View File

@ -14,14 +14,17 @@
# limitations under the License.
""" DeBERTa model configuration"""
from collections import OrderedDict
from typing import Any, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
if TYPE_CHECKING:
from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
logger = logging.get_logger(__name__)
DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@ -169,7 +172,7 @@ class DebertaOnnxConfig(OnnxConfig):
seq_length: int = -1,
num_choices: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
framework: Optional["TensorType"] = None,
num_channels: int = 3,
image_width: int = 40,
image_height: int = 40,

View File

@ -14,14 +14,17 @@
# limitations under the License.
""" DeBERTa-v2 model configuration"""
from collections import OrderedDict
from typing import Any, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
if TYPE_CHECKING:
from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
logger = logging.get_logger(__name__)
DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@ -170,7 +173,7 @@ class DebertaV2OnnxConfig(OnnxConfig):
seq_length: int = -1,
num_choices: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
framework: Optional["TensorType"] = None,
num_channels: int = 3,
image_width: int = 40,
image_height: int = 40,