Different behavior in DistilBERT when using "inputs_embeds" (#21752)
* Different behavior in DistilBERT when using "inputs_embeds"

  Fixes #21089

* fix failing test
This commit is contained in:
parent 13489248fa
commit 14f33205a7
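In the old code, embeddings passed through `inputs_embeds` bypassed the `Embeddings` module entirely, so they never received position embeddings, LayerNorm, or dropout, and the result differed from the equivalent `input_ids` call (issue #21089). A minimal sketch of the behavior this commit targets follows; the `distilbert-base-uncased` checkpoint and the comparison tolerance are illustrative assumptions, not part of the commit.

import torch
from transformers import AutoTokenizer, DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

enc = tokenizer("Hello world", return_tensors="pt")

# Pre-compute the word embeddings and feed them in place of the token ids.
inputs_embeds = model.get_input_embeddings()(enc["input_ids"])  # (bs, seq_length, dim)

with torch.no_grad():
    out_ids = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    out_embeds = model(inputs_embeds=inputs_embeds, attention_mask=enc["attention_mask"])

# With this commit, both paths run through the same Embeddings module
# (position embeddings + LayerNorm + dropout) and are expected to match.
print(torch.allclose(out_ids.last_hidden_state, out_embeds.last_hidden_state, atol=1e-6))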
@@ -105,15 +105,22 @@ class Embeddings(nn.Module):
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )
 
-    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Parameters:
-            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_ids (torch.Tensor):
+                torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_embeds (*optional*, torch.Tensor):
+                The pre-computed word embeddings. Can only be passed if the input ids are `None`.
+
 
         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
         embeddings)
         """
-        seq_length = input_ids.size(1)
+        if input_ids is not None:
+            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+
+        seq_length = input_embeds.size(1)
 
         # Setting the position-ids to the registered buffer in constructor, it helps
         # when tracing the model without passing position-ids, solves
@@ -124,10 +131,9 @@ class Embeddings(nn.Module):
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
-        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
 
-        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
         embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
         return embeddings
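For reference, a simplified standalone sketch of the logic the two hunks above implement (a toy module, not the library class; the vocabulary size, hidden dimension, and dropout rate are arbitrary): build the word embeddings from `input_ids` when ids are given, otherwise use the pre-computed `input_embeds`, then apply position embeddings, LayerNorm, and dropout in both cases.

from typing import Optional

import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    def __init__(self, vocab_size: int = 100, max_position_embeddings: int = 32, dim: int = 16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)
        self.position_embeddings = nn.Embedding(max_position_embeddings, dim)
        self.LayerNorm = nn.LayerNorm(dim, eps=1e-12)
        self.dropout = nn.Dropout(0.1)
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids: Optional[torch.Tensor], input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Same dispatch as the patched Embeddings.forward: build the word embeddings
        # from the ids when they are given, otherwise use the ones passed in.
        if input_ids is not None:
            input_embeds = self.word_embeddings(input_ids)  # (bs, seq_length, dim)

        seq_length = input_embeds.size(1)
        position_ids = self.position_ids[:, :seq_length]  # (1, seq_length)

        embeddings = input_embeds + self.position_embeddings(position_ids)  # (bs, seq_length, dim)
        return self.dropout(self.LayerNorm(embeddings))


emb = ToyEmbeddings().eval()  # eval() disables dropout so the two calls are comparable
ids = torch.randint(0, 100, (2, 10))
with torch.no_grad():
    assert torch.allclose(emb(ids), emb(None, input_embeds=emb.word_embeddings(ids)))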
@@ -573,10 +579,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
+
         return self.transformer(
-            x=inputs_embeds,
+            x=embeddings,
             attn_mask=attention_mask,
             head_mask=head_mask,
             output_attentions=output_attentions,
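One hedged illustration of why the `inputs_embeds` path matters once it matches the `input_ids` path: it allows taking gradients with respect to the embedded inputs, e.g. for a rough token-saliency signal. The checkpoint and the saliency recipe below are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoTokenizer, DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

enc = tokenizer("inputs_embeds now behaves like input_ids", return_tensors="pt")

# Take gradients with respect to the embedded inputs, something the integer
# input_ids path cannot provide.
embeds = model.get_input_embeddings()(enc["input_ids"]).detach().requires_grad_(True)

out = model(inputs_embeds=embeds, attention_mask=enc["attention_mask"])
out.last_hidden_state.sum().backward()

saliency = embeds.grad.norm(dim=-1)  # (bs, seq_length): per-token gradient magnitude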