Different behavior in DistilBERT when using "inputs_embeds" (#21752)

* Different behavior in DistilBERT when using "inputs_embeds"
Fixes #21089

* fix failing test
Arthur 2023-02-24 09:48:07 +01:00 committed by GitHub
parent 13489248fa
commit 14f33205a7

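Before this patch, a user-supplied `inputs_embeds` tensor bypassed the `Embeddings` module entirely (see the `DistilBertModel` hunk below), so position embeddings, LayerNorm and dropout were never applied to it. The sketch below is not part of the commit; it assumes the public `distilbert-base-uncased` checkpoint and shows the equivalence between the `input_ids` and `inputs_embeds` paths that the fix restores.

# Minimal sketch (assumption: distilbert-base-uncased checkpoint is available).
import torch
from transformers import AutoTokenizer, DistilBertModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()  # eval() disables dropout

input_ids = tokenizer("Hello world", return_tensors="pt")["input_ids"]

with torch.no_grad():
    # Path 1: let the model embed the token ids itself.
    out_ids = model(input_ids=input_ids).last_hidden_state
    # Path 2: pre-compute the word embeddings and pass them directly.
    word_embeds = model.get_input_embeddings()(input_ids)
    out_embeds = model(inputs_embeds=word_embeds).last_hidden_state

# With the fix, both paths go through the Embeddings module and agree.
print(torch.allclose(out_ids, out_embeds, atol=1e-6))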

@@ -105,15 +105,22 @@ class Embeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
-    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Parameters:
-            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_ids (torch.Tensor):
+                torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_embeds (*optional*, torch.Tensor):
+                The pre-computed word embeddings. Can only be passed if the input ids are `None`.
         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
         embeddings)
         """
-        seq_length = input_ids.size(1)
+        if input_ids is not None:
+            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+        seq_length = input_embeds.size(1)
         # Setting the position-ids to the registered buffer in constructor, it helps
         # when tracing the model without passing position-ids, solves
@@ -124,10 +131,9 @@ class Embeddings(nn.Module):
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
-        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
-        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
         embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
         return embeddings
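Reading the two hunks above together, the patched `Embeddings.forward` roughly becomes the method below. This is a reconstruction for readability, not the verbatim file: the `position_ids`-buffer branch elided between the hunks is an assumption based on the surrounding comments, and the imports are only there to make the sketch self-contained.

from typing import Optional

import torch

def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Either embed the token ids here, or take the pre-computed embeddings as given.
    if input_ids is not None:
        input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)

    seq_length = input_embeds.size(1)

    # Assumed branch: use the position_ids buffer registered in the constructor when it
    # exists, otherwise build the ids on the fly as shown in the second hunk above.
    if hasattr(self, "position_ids"):
        position_ids = self.position_ids[:, :seq_length]
    else:
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

    position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

    # Position embeddings, LayerNorm and dropout now apply to both input paths.
    embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
    embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
    embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
    return embeddings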
@@ -573,10 +579,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
         return self.transformer(
-            x=inputs_embeds,
+            x=embeddings,
             attn_mask=attention_mask,
             head_mask=head_mask,
             output_attentions=output_attentions,
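With the model-level hunk above, both `input_ids` and `inputs_embeds` are now routed through `self.embeddings`, so embedding-level tricks behave the same as the regular token path. As a purely illustrative follow-up (not part of this PR; the prompt length and tensor names are made up), a soft-prompt style call could look like this:

# Hypothetical usage sketch: prepend trainable "soft prompt" vectors to the word
# embeddings and pass them via inputs_embeds with a matching attention mask.
import torch
from transformers import AutoTokenizer, DistilBertModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()

inputs = tokenizer("Hello world", return_tensors="pt")
word_embeds = model.get_input_embeddings()(inputs["input_ids"])  # (1, seq, dim)

prompt_len = 4  # illustrative value, not from the PR
soft_prompt = torch.randn(1, prompt_len, word_embeds.size(-1)) * 0.02  # (1, 4, dim)

inputs_embeds = torch.cat([soft_prompt, word_embeds], dim=1)
attention_mask = torch.cat(
    [torch.ones(1, prompt_len, dtype=torch.long), inputs["attention_mask"]], dim=1
)

with torch.no_grad():
    out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

# Thanks to the fix, the soft prompt also receives position embeddings and LayerNorm.
print(out.last_hidden_state.shape)  # (1, prompt_len + seq, dim)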