Fix ErnieMEmbeddings device issue (#21726)

* remove .parameters()).device
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 2f2b19ff40
commit aff87da15b
@@ -77,7 +77,7 @@ class ErnieMEmbeddings(nn.Module):
             inputs_embeds = self.word_embeddings(input_ids)
         if position_ids is None:
             input_shape = inputs_embeds.size()[:-1]
-            ones = torch.ones(input_shape, dtype=torch.int64)
+            ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
             seq_length = torch.cumsum(ones, dim=1)
             position_ids = seq_length - ones

@@ -85,7 +85,6 @@ class ErnieMEmbeddings(nn.Module):
                 position_ids = position_ids + past_key_values_length
         # to mimic paddlenlp implementation
         position_ids += 2
-        position_ids = position_ids.to(next(self.position_embeddings.parameters()).device)
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeds + position_embeddings
         embeddings = self.layer_norm(embeddings)
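For context, a minimal sketch of why this change works: creating the ones tensor on inputs_embeds.device means everything derived from it, including position_ids, is born on the correct device, so the removed .to(next(self.position_embeddings.parameters()).device) transfer becomes unnecessary. TinyEmbeddings below and its sizes are made-up stand-ins for ErnieMEmbeddings, not the real class.

# Minimal sketch, assuming a simplified stand-in for ErnieMEmbeddings.
import torch
from torch import nn


class TinyEmbeddings(nn.Module):
    def __init__(self, vocab_size=100, max_position_embeddings=64, hidden_size=16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

    def forward(self, input_ids):
        inputs_embeds = self.word_embeddings(input_ids)
        input_shape = inputs_embeds.size()[:-1]
        # The fix: build the helper tensor on the activations' device, so
        # position_ids is already where position_embeddings' weights live and
        # no cross-device .to(...) call is needed afterwards.
        ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
        seq_length = torch.cumsum(ones, dim=1)
        position_ids = seq_length - ones
        position_ids += 2  # to mimic paddlenlp implementation
        return inputs_embeds + self.position_embeddings(position_ids)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyEmbeddings().to(device)
input_ids = torch.randint(0, 100, (2, 8), device=device)
print(model(input_ids).device)  # same device as the inputs; no extra copy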