diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index f2ecd055564..6020b9fe70a 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -49,6 +49,7 @@ Ready-made configurations include the following architectures: - BART - BEiT - BERT +- BigBird - Blenderbot - BlenderbotSmall - CamemBERT diff --git a/src/transformers/models/big_bird/__init__.py b/src/transformers/models/big_bird/__init__.py index ec7d2a4c552..40aec27b952 100644 --- a/src/transformers/models/big_bird/__init__.py +++ b/src/transformers/models/big_bird/__init__.py @@ -28,7 +28,7 @@ from ...utils import ( _import_structure = { - "configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"], + "configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdOnnxConfig"], } if is_sentencepiece_available(): @@ -66,7 +66,7 @@ if is_flax_available(): ] if TYPE_CHECKING: - from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig + from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdOnnxConfig if is_sentencepiece_available(): from .tokenization_big_bird import BigBirdTokenizer diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py index eac6aff79de..15efd9c2c8f 100644 --- a/src/transformers/models/big_bird/configuration_big_bird.py +++ b/src/transformers/models/big_bird/configuration_big_bird.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" BigBird model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -160,3 +163,14 @@ class BigBirdConfig(PretrainedConfig): self.block_size = block_size self.num_random_blocks = num_random_blocks self.classifier_dropout = classifier_dropout + + +class BigBirdOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 47da3ff721a..f255a363297 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -3000,7 +3000,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): # setting lengths logits to `-inf` logits_mask = self.prepare_question_mask(question_lengths, seqlen) if token_type_ids is None: - token_type_ids = (~logits_mask).long() + token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask logits_mask = logits_mask logits_mask[:, 0] = False logits_mask.unsqueeze_(2) @@ -3063,5 +3063,5 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): # q_lengths -> (bz, 1) mask = torch.arange(0, maxlen).to(q_lengths.device) mask.unsqueeze_(0) # -> (1, maxlen) - mask = mask < q_lengths + mask = torch.where(mask < q_lengths, 1, 0) return mask diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index cf5e55c521d..3875da445f1 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -6,6 +6,7 @@ from ..models.albert import AlbertOnnxConfig from ..models.bart import BartOnnxConfig from ..models.beit import BeitOnnxConfig from ..models.bert import BertOnnxConfig +from ..models.big_bird 
import BigBirdOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.camembert import CamembertOnnxConfig @@ -156,6 +157,15 @@ class FeaturesManager: "question-answering", onnx_config_cls=BertOnnxConfig, ), + "big-bird": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "token-classification", + "question-answering", + onnx_config_cls=BigBirdOnnxConfig, + ), "ibert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index ba8d51158ff..1ddaa78ce6c 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -172,6 +172,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase): PYTORCH_EXPORT_MODELS = { ("albert", "hf-internal-testing/tiny-albert"), ("bert", "bert-base-cased"), + ("big-bird", "google/bigbird-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"), ("camembert", "camembert-base"), ("distilbert", "distilbert-base-cased"),