Mirror of https://github.com/huggingface/transformers.git
DeepSpeed ZeRO-3 handling when resizing embedding layers (#26259)
* fix failing deepspeed slow tests
* fixes
parent 39df4eca73
commit ffbf989f0d
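For context, the scenario this commit targets is resizing a model's token embeddings while DeepSpeed ZeRO-3 is active. The sketch below is not part of the commit; the model name, pad token, and toy ZeRO-3 config are placeholder assumptions.

# Hypothetical reproduction sketch, not part of this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import HfDeepSpeedConfig

# Toy ZeRO-3 config; a real run would also set optimizer, dtype, batch sizes, etc.
ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 1,
}

# Keeping this object alive is what makes is_deepspeed_zero3_enabled() return True
# inside the methods patched in this diff (detection is weakref-based).
dschf = HfDeepSpeedConfig(ds_config)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = AutoModelForCausalLM.from_pretrained("gpt2")
# Ends up in _get_resized_embeddings / _get_resized_lm_head shown below.
model.resize_token_embeddings(len(tokenizer))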
@@ -1550,7 +1550,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
 
-        if old_num_tokens == new_num_tokens:
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
             return old_embeddings
 
         if not isinstance(old_embeddings, nn.Embedding):
@@ -1560,40 +1560,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 f" {nn.Embedding}."
             )
 
+        # Build new embeddings
+
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_embeddings = nn.Embedding(
+            new_num_tokens,
+            old_embedding_dim,
+            device=old_embeddings.weight.device,
+            dtype=old_embeddings.weight.dtype,
+        )
+
+        # initialize all new embeddings (in particular added tokens)
+        self._init_weights(new_embeddings)
+
+        # Copy token embeddings from the previous weights
+
         # numbers of tokens to copy
         n = min(old_num_tokens, new_num_tokens)
 
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
-            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
-                # Build new embeddings
-                new_embeddings = nn.Embedding(
-                    new_num_tokens,
-                    old_embedding_dim,
-                    device=old_embeddings.weight.device,
-                    dtype=old_embeddings.weight.dtype,
-                )
-
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                # initialize all new embeddings (in particular added tokens)
-                self._init_weights(new_embeddings)
-
-                # Copy token embeddings from the previous weights
                 new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
         else:
-            # Build new embeddings
-            new_embeddings = nn.Embedding(
-                new_num_tokens,
-                old_embedding_dim,
-                device=old_embeddings.weight.device,
-                dtype=old_embeddings.weight.dtype,
-            )
-
-            # initialize all new embeddings (in particular added tokens)
-            self._init_weights(new_embeddings)
-
-            # Copy token embeddings from the previous weights
             new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
 
         return new_embeddings
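The new code builds the resized embedding with plain torch, so its shape stays visible to the rest of the modeling code, and only the weight copy touches the ZeRO-3 partitioned parameters. The standalone sketch below (not from the repo; the function name is made up) shows that copy pattern in isolation.

# Illustrative helper, not from the repo: copy the first n rows of a ZeRO-3
# partitioned embedding into a freshly created (non-partitioned) one.
import deepspeed
import torch.nn as nn


def copy_rows_under_zero3(old_emb: nn.Embedding, new_emb: nn.Embedding, n: int) -> None:
    params = [old_emb.weight, new_emb.weight]
    # GatheredParameters temporarily materializes the partitioned weights;
    # modifier_rank=0 means only rank 0 writes, and its update is broadcast
    # back to the other ranks when the context exits.
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        new_emb.weight.data[:n, :] = old_emb.weight.data[:n, :]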
@@ -1636,7 +1630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
         )
 
-        if old_num_tokens == new_num_tokens:
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
             return old_lm_head
 
         if not isinstance(old_lm_head, nn.Linear):
@@ -1650,51 +1644,50 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
 
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_lm_head = nn.Linear(
+            *new_lm_head_shape,
+            bias=has_new_lm_head_bias,
+            device=old_lm_head.weight.device,
+            dtype=old_lm_head.weight.dtype,
+        )
+
+        # initialize new lm head (in particular added tokens)
+        self._init_weights(new_lm_head)
+
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
 
-        # XXX: put the long block of code in a wrapper
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
-            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
-                new_lm_head = nn.Linear(
-                    *new_lm_head_shape,
-                    bias=has_new_lm_head_bias,
-                    device=old_lm_head.weight.device,
-                    dtype=old_lm_head.weight.dtype,
-                )
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                self._init_weights(new_lm_head)
-                # Copy old lm head weights to new lm head
-                if not transposed:
-                    new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
-                else:
-                    new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
-
-                # Copy bias weights to new lm head
-                if has_new_lm_head_bias:
-                    new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+                self._copy_lm_head_original_to_resized(
+                    new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+                )
         else:
-            new_lm_head = nn.Linear(
-                *new_lm_head_shape,
-                bias=has_new_lm_head_bias,
-                device=old_lm_head.weight.device,
-                dtype=old_lm_head.weight.dtype,
+            self._copy_lm_head_original_to_resized(
+                new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
             )
-            self._init_weights(new_lm_head)
-            # Copy old lm head weights to new lm head
-            if not transposed:
-                new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
-            else:
-                new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
-
-            # Copy bias weights to new lm head
-            if has_new_lm_head_bias:
-                new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
 
         return new_lm_head
 
+    def _copy_lm_head_original_to_resized(
+        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+    ):
+        # Copy old lm head weights to new lm head
+        if not transposed:
+            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+        else:
+            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
+
+        # Copy bias weights to new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+
     def resize_position_embeddings(self, new_num_position_embeddings: int):
         raise NotImplementedError(
             f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
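The new _copy_lm_head_original_to_resized helper replaces the duplicated copy blocks in the two branches above. A plain-PyTorch sketch of what it does in the non-transposed case (toy sizes, no DeepSpeed involved):

# Toy illustration, not from the repo: carry over the old head's rows and bias,
# leaving the newly added rows with their fresh initialization.
import torch
import torch.nn as nn

old_lm_head = nn.Linear(16, 100, bias=True)   # hidden_size=16, old vocab of 100
new_lm_head = nn.Linear(16, 104, bias=True)   # vocab grown by 4 tokens

num_tokens_to_copy = min(100, 104)
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]

assert torch.equal(new_lm_head.weight[:100], old_lm_head.weight)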