mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
parent
a960406722
commit
9c9db751e2
@ -49,6 +49,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- BART
|
- BART
|
||||||
- BEiT
|
- BEiT
|
||||||
- BERT
|
- BERT
|
||||||
|
- BigBird
|
||||||
- Blenderbot
|
- Blenderbot
|
||||||
- BlenderbotSmall
|
- BlenderbotSmall
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
|
@ -28,7 +28,7 @@ from ...utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"],
|
"configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
@ -66,7 +66,7 @@ if is_flax_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
|
from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdOnnxConfig
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .tokenization_big_bird import BigBirdTokenizer
|
from .tokenization_big_bird import BigBirdTokenizer
|
||||||
|
@ -13,8 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" BigBird model configuration"""
|
""" BigBird model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@ -160,3 +163,14 @@ class BigBirdConfig(PretrainedConfig):
|
|||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.num_random_blocks = num_random_blocks
|
self.num_random_blocks = num_random_blocks
|
||||||
self.classifier_dropout = classifier_dropout
|
self.classifier_dropout = classifier_dropout
|
||||||
|
|
||||||
|
|
||||||
|
class BigBirdOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@ -3000,7 +3000,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
|
|||||||
# setting lengths logits to `-inf`
|
# setting lengths logits to `-inf`
|
||||||
logits_mask = self.prepare_question_mask(question_lengths, seqlen)
|
logits_mask = self.prepare_question_mask(question_lengths, seqlen)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = (~logits_mask).long()
|
token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask
|
||||||
logits_mask = logits_mask
|
logits_mask = logits_mask
|
||||||
logits_mask[:, 0] = False
|
logits_mask[:, 0] = False
|
||||||
logits_mask.unsqueeze_(2)
|
logits_mask.unsqueeze_(2)
|
||||||
@ -3063,5 +3063,5 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
|
|||||||
# q_lengths -> (bz, 1)
|
# q_lengths -> (bz, 1)
|
||||||
mask = torch.arange(0, maxlen).to(q_lengths.device)
|
mask = torch.arange(0, maxlen).to(q_lengths.device)
|
||||||
mask.unsqueeze_(0) # -> (1, maxlen)
|
mask.unsqueeze_(0) # -> (1, maxlen)
|
||||||
mask = mask < q_lengths
|
mask = torch.where(mask < q_lengths, 1, 0)
|
||||||
return mask
|
return mask
|
||||||
|
@ -6,6 +6,7 @@ from ..models.albert import AlbertOnnxConfig
|
|||||||
from ..models.bart import BartOnnxConfig
|
from ..models.bart import BartOnnxConfig
|
||||||
from ..models.beit import BeitOnnxConfig
|
from ..models.beit import BeitOnnxConfig
|
||||||
from ..models.bert import BertOnnxConfig
|
from ..models.bert import BertOnnxConfig
|
||||||
|
from ..models.big_bird import BigBirdOnnxConfig
|
||||||
from ..models.blenderbot import BlenderbotOnnxConfig
|
from ..models.blenderbot import BlenderbotOnnxConfig
|
||||||
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
@ -156,6 +157,15 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=BertOnnxConfig,
|
onnx_config_cls=BertOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"bigbird": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=BigBirdOnnxConfig,
|
||||||
|
),
|
||||||
"ibert": supported_features_mapping(
|
"ibert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
@ -172,6 +172,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
PYTORCH_EXPORT_MODELS = {
|
PYTORCH_EXPORT_MODELS = {
|
||||||
("albert", "hf-internal-testing/tiny-albert"),
|
("albert", "hf-internal-testing/tiny-albert"),
|
||||||
("bert", "bert-base-cased"),
|
("bert", "bert-base-cased"),
|
||||||
|
("bigbird", "google/bigbird-roberta-base"),
|
||||||
("ibert", "kssteven/ibert-roberta-base"),
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
|
Loading…
Reference in New Issue
Block a user