mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Resize embeds with DeepSpeed (#32214)
* fix resize when deepspeed * deepsped uses new embeds * we needed this
This commit is contained in:
parent
fad15fba78
commit
c46edfb823
@ -1980,12 +1980,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if new_num_tokens is None and pad_to_multiple_of is None:
|
||||
return model_embeds
|
||||
|
||||
# Since we are basically resuing the same old embeddings with new weight values, gathering is required
|
||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
|
||||
vocab_size = model_embeds.weight.shape[0]
|
||||
else:
|
||||
vocab_size = model_embeds.weight.shape[0]
|
||||
|
||||
# Update base model and current model config
|
||||
if hasattr(self.config, "text_config"):
|
||||
self.config.text_config.vocab_size = model_embeds.weight.shape[0]
|
||||
self.config.text_config.vocab_size = vocab_size
|
||||
else:
|
||||
self.config.vocab_size = model_embeds.weight.shape[0]
|
||||
self.vocab_size = model_embeds.weight.shape[0]
|
||||
self.config.vocab_size = vocab_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Tie weights again if needed
|
||||
self.tie_weights()
|
||||
@ -2139,7 +2149,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
params = [old_embeddings.weight, new_embeddings.weight]
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
old_embeddings.weight.data = new_embeddings.weight.data
|
||||
old_embeddings.weight = new_embeddings.weight
|
||||
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
|
||||
|
||||
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
|
||||
|
Loading…
Reference in New Issue
Block a user