[Bert, et al] fix early device assignment (#14447)

* fix early device assignment

* more models
This commit is contained in:
Stas Bekman 2021-11-18 11:47:49 -08:00 committed by GitHub
parent 83ef8bcac2
commit 72a6bf33c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 7 additions and 7 deletions

View File

@ -219,7 +219,7 @@ class AlbertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)

View File

@ -182,7 +182,7 @@ class BertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)

View File

@ -261,7 +261,7 @@ class BigBirdEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
# End copy

View File

@ -202,7 +202,7 @@ class ConvBertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)

View File

@ -172,7 +172,7 @@ class ElectraEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)

View File

@ -120,7 +120,7 @@ class FNetEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)

View File

@ -89,7 +89,7 @@ class RobertaEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)