Fix the issue of using only inputs_embeds in convbert model (#21398)

* Fix the input embeds issue with tests

* Fix black and isort issue

* Clean up tests

* Add slow tag to the test introduced

* Incorporate PR feedbacks
This commit is contained in:
raghavanone 2023-02-01 20:17:25 +05:30 committed by GitHub
parent 65b5035a1d
commit 77db257e2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 1 deletions

View File

@ -818,12 +818,12 @@ class ConvBertModel(ConvBertPreTrainedModel):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:

View File

@ -437,6 +437,17 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
def test_model_for_input_embeds(self):
batch_size = 2
seq_length = 10
inputs_embeds = torch.rand([batch_size, seq_length, 768])
config = self.model_tester.get_config()
model = ConvBertModel(config=config)
model.to(torch_device)
model.eval()
result = model(inputs_embeds=inputs_embeds)
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
@require_torch
class ConvBertModelIntegrationTest(unittest.TestCase):