Fix ErnieMEmbeddings device issue (#21726)

* remove the next(self.position_embeddings.parameters()).device lookup

* create ones directly on inputs_embeds.device instead

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2023-02-22 10:57:34 +01:00 committed by GitHub
parent 2f2b19ff40
commit aff87da15b

@@ -77,7 +77,7 @@ class ErnieMEmbeddings(nn.Module):
             inputs_embeds = self.word_embeddings(input_ids)
         if position_ids is None:
             input_shape = inputs_embeds.size()[:-1]
-            ones = torch.ones(input_shape, dtype=torch.int64)
+            ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
             seq_length = torch.cumsum(ones, dim=1)
             position_ids = seq_length - ones
@@ -85,7 +85,6 @@ class ErnieMEmbeddings(nn.Module):
                 position_ids = position_ids + past_key_values_length
         # to mimic paddlenlp implementation
         position_ids += 2
-        position_ids = position_ids.to(next(self.position_embeddings.parameters()).device)
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeds + position_embeddings
         embeddings = self.layer_norm(embeddings)
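
For context, here is a minimal, self-contained sketch of the pattern this commit moves to: allocate helper tensors on the incoming tensor's device up front, instead of creating them on the default (CPU) device and moving the result afterwards via next(self.position_embeddings.parameters()).device. This is not the actual ErnieMEmbeddings code; the module name and shapes are illustrative.

import torch
import torch.nn as nn

class PositionIdsSketch(nn.Module):
    """Illustrative module: derives position ids from input embeddings."""

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        input_shape = inputs_embeds.size()[:-1]
        # Before the fix: torch.ones(...) landed on the CPU, so position_ids
        # had to be moved to the embedding weight's device with an extra
        # next(module.parameters()).device lookup.
        # After the fix: allocate directly on the input's device.
        ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
        seq_length = torch.cumsum(ones, dim=1)
        return seq_length - ones

if torch.cuda.is_available():
    embeds = torch.randn(2, 5, 8, device="cuda")
    position_ids = PositionIdsSketch()(embeds)
    assert position_ids.device == embeds.device  # no cross-device move needed

Deriving the device from inputs_embeds also drops the per-call parameters() iteration, which is presumably why the second hunk can delete the .to(...) line outright.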