Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[Bert, et al] fix early device assignment (#14447)

* fix early device assignment

* more models
This commit is contained in:
parent 83ef8bcac2
commit 72a6bf33c0
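
The change is identical in every hunk below: the token_type_ids buffer is now created without an explicit device= argument. Buffers added with register_buffer are moved along with the module by nn.Module.to(), so building the tensor on self.position_ids.device at construction time is unnecessary and assigns a device before the model has actually been placed. A minimal sketch of the pattern, assuming the usual .to(device) placement later on (ToyEmbeddings and its argument are hypothetical names, not part of this patch):

import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    # Hypothetical, trimmed-down stand-in for the *Embeddings classes patched below.
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # Create the buffer on the default (CPU) device; nn.Module.to() will move it
        # together with the rest of the module, so no early device assignment is needed.
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )


module = ToyEmbeddings()
if torch.cuda.is_available():
    module = module.to("cuda")
# Both buffers end up wherever the module was moved, without being pinned at __init__ time.
print(module.token_type_ids.device == module.position_ids.device)  # True
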
@@ -219,7 +219,7 @@ class AlbertEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )

@@ -182,7 +182,7 @@ class BertEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )

@@ -261,7 +261,7 @@ class BigBirdEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )
         # End copy

@@ -202,7 +202,7 @@ class ConvBertEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )

@@ -172,7 +172,7 @@ class ElectraEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )

@@ -120,7 +120,7 @@ class FNetEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )

@@ -89,7 +89,7 @@ class RobertaEmbeddings(nn.Module):
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                torch.zeros(self.position_ids.size(), dtype=torch.long),
                 persistent=False,
             )