Fix VisualBert Embeddings (#13017)

This commit is contained in:
Gunjan Chhablani 2021-08-12 13:27:34 +05:30 committed by GitHub
parent 53b38d6269
commit c4e1586db8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -123,7 +123,7 @@ class VisualBertEmbeddings(nn.Module):
inputs_embeds = self.word_embeddings(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.input_embeds.device)
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)