Different behavior in DistilBERT when using "inputs_embeds" (#21752)
* Different behavior in DistilBERT when using "inputs_embeds" (fixes #21089)
* fix failing test
parent 13489248fa
commit 14f33205a7
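For context (this sketch is not part of the commit): the diff below routes user-supplied `inputs_embeds` through the `Embeddings` module, so position embeddings, LayerNorm, and dropout are applied to them just as they are to embeddings looked up from `input_ids`. A minimal sketch of the two call paths this affects, assuming the public `DistilBertModel` API and the `distilbert-base-uncased` checkpoint:

import torch
from transformers import AutoTokenizer, DistilBertModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()
ids = tokenizer("Hello world", return_tensors="pt").input_ids

# Path 1: the model looks up the word embeddings from the token ids.
out_ids = model(input_ids=ids).last_hidden_state

# Path 2: pass pre-computed word embeddings instead of token ids.
vecs = model.get_input_embeddings()(ids)               # (bs, seq_len, dim)
out_vecs = model(inputs_embeds=vecs).last_hidden_state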
@@ -105,15 +105,22 @@ class Embeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
-    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Parameters:
-            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_ids (torch.Tensor):
+                torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_embeds (*optional*, torch.Tensor):
+                The pre-computed word embeddings. Can only be passed if the input ids are `None`.
+
 
         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
         embeddings)
         """
-        seq_length = input_ids.size(1)
+        if input_ids is not None:
+            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+
+        seq_length = input_embeds.size(1)
 
         # Setting the position-ids to the registered buffer in constructor, it helps
         # when tracing the model without passing position-ids, solves
@@ -124,10 +131,9 @@ class Embeddings(nn.Module):
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
-        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
 
-        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
         embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
         return embeddings
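A small sketch of what the two hunks above change inside `Embeddings` (hypothetical direct call, reusing `model` and `ids` from the sketch above, and assuming the registered `position_ids` buffer path shown at the top of the first hunk): `forward` now accepts pre-computed vectors as a second argument and still adds position embeddings, LayerNorm, and dropout to them.

# Token-id path vs. pre-computed path through the updated Embeddings module.
# In eval mode dropout is a no-op, so the two results should coincide.
with torch.no_grad():
    vecs = model.get_input_embeddings()(ids)      # (bs, seq_len, dim)
    emb_from_ids = model.embeddings(ids)          # looks up word embeddings internally
    emb_from_vecs = model.embeddings(None, vecs)  # uses the supplied vectors as-is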
@@ -573,10 +579,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
+
         return self.transformer(
-            x=inputs_embeds,
+            x=embeddings,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
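Continuing the sketch, and in the spirit of the commit's "fix failing test" item, a hedged end-to-end check: with the change above, the `input_ids` path and the `inputs_embeds` path should produce matching hidden states, because both now go through the same embedding pipeline.

# Regression-style check (sketch): the two input forms should now agree in eval mode.
with torch.no_grad():
    out_ids = model(input_ids=ids).last_hidden_state
    out_vecs = model(inputs_embeds=model.get_input_embeddings()(ids)).last_hidden_state
assert torch.allclose(out_ids, out_vecs, atol=1e-6)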