mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00

* Change the way tracing happens, enabling dynamic axes out of the box * Update the tests and modeling xlnet * Skip recording for leaf modules, to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors). * Add comments and make tracing work for gpt-j and xlnet * Refactor things related to num_choices (and batch_size, sequence_length) * Update fx to work on PyTorch 1.10 * Postpone autowrap_function feature usage for later * Add copyrights * Remove unnecessary file * Fix issue with add_new_model_like * Apply suggestions
626 lines
25 KiB
Python
Executable File
626 lines
25 KiB
Python
Executable File
# coding=utf-8
|
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
from transformers import BertConfig, is_torch_available
|
|
from transformers.models.auto import get_values
|
|
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
|
|
|
from .test_configuration_common import ConfigTester
|
|
from .test_generation_utils import GenerationTesterMixin
|
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
MODEL_FOR_PRETRAINING_MAPPING,
|
|
BertForMaskedLM,
|
|
BertForMultipleChoice,
|
|
BertForNextSentencePrediction,
|
|
BertForPreTraining,
|
|
BertForQuestionAnswering,
|
|
BertForSequenceClassification,
|
|
BertForTokenClassification,
|
|
BertLMHeadModel,
|
|
BertModel,
|
|
)
|
|
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
|
|
|
|
|
class BertModelTester:
    """
    Helper that builds a small `BertConfig` together with random inputs and runs shape checks on every BERT head.

    All assertions are delegated to `parent`, on which this class calls `assertEqual` and `assertTrue`.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        """
        Record every argument unchanged as an attribute of the same name.
        """
        # Object that receives the assertions.
        self.parent = parent

        # Input dimensions and the switches deciding which optional inputs get generated.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Values forwarded to `BertConfig` by `get_config`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

        # Label-space sizes used when drawing random labels and when checking logits shapes.
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices

        self.scope = scope

    @staticmethod
    def _ready_model(model_class, config):
        """
        Instantiate `model_class` from `config`, move it to `torch_device` and switch it to eval mode.
        """
        model = model_class(config=config)
        model.to(torch_device)
        model.eval()
        return model

    def prepare_config_and_inputs(self):
        """
        Returns `(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)`.

        The mask, the token type ids and the three label tensors are `None` when the matching `use_*` flag is off.
        """
        batch_by_seq = [self.batch_size, self.seq_length]

        # The random draws below are kept in a fixed order: ids, mask, token types, then labels.
        input_ids = ids_tensor(batch_by_seq, self.vocab_size)
        input_mask = random_attention_mask(batch_by_seq) if self.use_input_mask else None
        token_type_ids = ids_tensor(batch_by_seq, self.type_vocab_size) if self.use_token_type_ids else None

        sequence_labels = token_labels = choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor(batch_by_seq, self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        return BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
        """
        Returns the tuple of `prepare_config_and_inputs` with `config.is_decoder` set to `True`, followed by random
        `encoder_hidden_states` and a random 0/1 `encoder_attention_mask`.
        """
        config, *encoder_side_inputs = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (config, *encoder_side_inputs, encoder_hidden_states, encoder_attention_mask)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertModel` with fewer and fewer optional inputs and checks the output shapes of the last call.
        """
        model = self._ready_model(BertModel, config)

        # The first two calls only need to run; the result of the last call is the one inspected.
        model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """
        Runs `BertModel` with cross-attention enabled, dropping the encoder inputs one call at a time, and checks the
        output shapes of the last call.
        """
        config.add_cross_attention = True
        model = self._ready_model(BertModel, config)

        decoder_inputs = {"attention_mask": input_mask, "token_type_ids": token_type_ids}

        # Only the last call is inspected; the earlier ones just have to run.
        model(
            input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **decoder_inputs,
        )
        model(input_ids, encoder_hidden_states=encoder_hidden_states, **decoder_inputs)
        result = model(input_ids, **decoder_inputs)

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """
        Runs `BertLMHeadModel` with `token_labels` as labels and checks the logits shape.
        """
        lm_model = self._ready_model(BertLMHeadModel, config)
        result = lm_model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)

        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.logits.shape, expected_logits_shape)

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForMaskedLM` with `token_labels` as labels and checks the logits shape.
        """
        mlm_model = self._ready_model(BertForMaskedLM, config)
        result = mlm_model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)

        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.logits.shape, expected_logits_shape)

    def create_and_check_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """
        Runs `BertLMHeadModel` with cross-attention enabled, with and without an encoder attention mask, and checks
        the logits shape of the last call.
        """
        config.add_cross_attention = True
        lm_model = self._ready_model(BertLMHeadModel, config)

        labelled_inputs = {
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
            "labels": token_labels,
            "encoder_hidden_states": encoder_hidden_states,
        }

        # First with the encoder attention mask, then without it; only the second result is inspected.
        lm_model(input_ids, encoder_attention_mask=encoder_attention_mask, **labelled_inputs)
        result = lm_model(input_ids, **labelled_inputs)

        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.logits.shape, expected_logits_shape)

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """
        Checks that feeding several new tokens together with `past_key_values` gives the same hidden states as
        running the full, extended sequence without any cache.
        """
        config.is_decoder = True
        config.add_cross_attention = True
        model = self._ready_model(BertLMHeadModel, config)

        encoder_inputs = {
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }

        # first forward pass, asking the model to hand back its cache
        first_pass = model(input_ids, attention_mask=input_mask, use_cache=True, **encoder_inputs)
        past_key_values = first_pass.past_key_values

        # create several hypothetical next tokens, plus a random 0/1 mask for them
        num_new_tokens = 3
        next_tokens = ids_tensor((self.batch_size, num_new_tokens), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, num_new_tokens), vocab_size=2)

        # extend the original ids and attention mask with the new positions
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        full_run = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
            **encoder_inputs,
        )
        output_from_no_past = full_run["hidden_states"][0]

        cached_run = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            **encoder_inputs,
        )
        output_from_past = cached_run["hidden_states"][0]

        # compare a single, randomly chosen index of the last dimension
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -num_new_tokens:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForNextSentencePrediction` with `sequence_labels` as labels and checks the logits shape.
        """
        nsp_model = self._ready_model(BertForNextSentencePrediction, config)
        result = nsp_model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForPreTraining` with both label kinds and checks the shapes of its two logits outputs.
        """
        pretraining_model = self._ready_model(BertForPreTraining, config)
        result = pretraining_model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )

        expected_prediction_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.prediction_logits.shape, expected_prediction_shape)
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForQuestionAnswering`, using `sequence_labels` for both start and end positions, and checks the
        shapes of the start and end logits.
        """
        qa_model = self._ready_model(BertForQuestionAnswering, config)
        result = qa_model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )

        expected_span_shape = (self.batch_size, self.seq_length)
        self.parent.assertEqual(result.start_logits.shape, expected_span_shape)
        self.parent.assertEqual(result.end_logits.shape, expected_span_shape)

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForSequenceClassification` with `config.num_labels` set from the tester and checks the logits shape.
        """
        config.num_labels = self.num_labels
        classifier = self._ready_model(BertForSequenceClassification, config)
        result = classifier(
            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForTokenClassification` with `config.num_labels` set from the tester and checks the logits shape.
        """
        config.num_labels = self.num_labels
        tagger = self._ready_model(BertForTokenClassification, config)
        result = tagger(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)

        expected_logits_shape = (self.batch_size, self.seq_length, self.num_labels)
        self.parent.assertEqual(result.logits.shape, expected_logits_shape)

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """
        Runs `BertForMultipleChoice` on inputs repeated once per choice and checks the logits shape.
        """
        config.num_choices = self.num_choices
        model = self._ready_model(BertForMultipleChoice, config)

        def repeat_per_choice(tensor):
            # Insert a new dimension at position 1 and expand it to `num_choices` entries.
            return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

        multiple_choice_inputs_ids = repeat_per_choice(input_ids)
        multiple_choice_token_type_ids = repeat_per_choice(token_type_ids)
        multiple_choice_input_mask = repeat_per_choice(input_mask)

        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        """
        Returns `(config, inputs_dict)` where `inputs_dict` holds `input_ids`, `token_type_ids` and `attention_mask`.
        """
        # Labels are generated by `prepare_config_and_inputs` but are not part of the returned dictionary.
        config, input_ids, token_type_ids, input_mask, *_ = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    # Runs the shared model/generation test suites on the BERT classes below, plus BERT-specific shape checks
    # delegated to `BertModelTester`. Test methods are documented with comments rather than docstrings so that
    # unittest's verbose output keeps showing the test names.

    all_model_classes = (
        (
            BertModel,
            BertLMHeadModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    fx_compatible = True

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        # Start from the inputs built by the parent implementation, then add the two all-zero label tensors
        # (`labels` and `next_sentence_label`) when labels are requested for a pretraining model class.
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels and model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
            inputs_dict["labels"] = torch.zeros(
                (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
            )
            inputs_dict["next_sentence_label"] = torch.zeros(
                self.model_tester.batch_size, dtype=torch.long, device=torch_device
            )
        return inputs_dict

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        # Re-run the base model check once per position embedding type, mutating the same config object each time.
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        config = config_and_inputs[0]
        # The loop variable is deliberately not called `type`, which would shadow the builtin.
        for position_embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config.position_embedding_type = position_embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        # Drop the generated mask so the model has to fall back to its default attention mask.
        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        # Only the first checkpoint of the archive list is loaded.
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    @slow
    @require_torch_gpu
    def test_torchscript_device_change(self):
        # Trace each model with CPU inputs, save the trace, reload it mapped onto `torch_device` and run it there.
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.torchscript = True

        for model_class in self.all_model_classes:
            # BertForMultipleChoice behaves incorrectly in JIT environments.
            if model_class == BertForMultipleChoice:
                # Skip this class only: a `return` here would end the test and leave every class listed after
                # BertForMultipleChoice in `all_model_classes` untested.
                continue

            model = model_class(config=config)

            # Keep the shared `inputs_dict` untouched so one class's prepared inputs never feed the next class.
            class_inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = class_inputs["input_ids"]
            attention_mask = class_inputs["attention_mask"]

            traced_model = torch.jit.trace(model, (input_ids.to("cpu"), attention_mask.to("cpu")))

            with tempfile.TemporaryDirectory() as tmp:
                traced_path = os.path.join(tmp, "bert.pt")
                torch.jit.save(traced_model, traced_path)
                loaded = torch.jit.load(traced_path, map_location=torch_device)
                loaded(input_ids.to(torch_device), attention_mask.to(torch_device))
|
|
|
|
|
|
@require_torch
class BertModelIntegrationTest(unittest.TestCase):
    # Each test loads a pretrained checkpoint, runs one fixed input through it and compares a 3x3 slice of the
    # first model output against hard-coded reference values.

    def _check_no_head_inference(self, checkpoint, expected_values):
        """
        Load `checkpoint` into `BertModel`, run the fixed 11-token input without gradients and assert both the
        output shape and the `[:, 1:4, 1:4]` slice (within `atol=1e-4` of `expected_values`).
        """
        model = BertModel.from_pretrained(checkpoint)
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        # The first position is masked out; the other ten are attended to.
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]

        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(expected_values)
        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head_absolute_embedding(self):
        self._check_no_head_inference(
            "bert-base-uncased",
            [[[0.4249, 0.1008, 0.7531], [0.3771, 0.1188, 0.7467], [0.4152, 0.1098, 0.7108]]],
        )

    @slow
    def test_inference_no_head_relative_embedding_key(self):
        self._check_no_head_inference(
            "zhiheng-huang/bert-base-uncased-embedding-relative-key",
            [[[0.0756, 0.3142, -0.5128], [0.3761, 0.3462, -0.5477], [0.2052, 0.3760, -0.1240]]],
        )

    @slow
    def test_inference_no_head_relative_embedding_key_query(self):
        self._check_no_head_inference(
            "zhiheng-huang/bert-base-uncased-embedding-relative-key-query",
            [[[0.6496, 0.3784, 0.8203], [0.8148, 0.5656, 0.2636], [-0.0681, 0.5597, 0.7045]]],
        )
|