fix resize_token_embeddings (#11572)

This commit is contained in:
Stas Bekman 2021-05-03 13:12:06 -07:00 committed by GitHub
parent fe82b1bfa0
commit 7c622482e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -682,7 +682,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
self.device, dtype=old_embeddings.weight.dtype
)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)