DeepSpeed ZeRO-3 handling when resizing embedding layers (#26259)

* fix failing deepspeed slow tests

* fixes
This commit is contained in:
Sourab Mangrulkar 2023-09-20 00:34:56 +05:30 committed by GitHub
parent 39df4eca73
commit ffbf989f0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1550,7 +1550,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
return old_embeddings
if not isinstance(old_embeddings, nn.Embedding):
@ -1560,40 +1560,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f" {nn.Embedding}."
)
# Build new embeddings
# When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
# because the shape of the new embedding layer is used across various modeling files
# as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
# to errors when training.
new_embeddings = nn.Embedding(
new_num_tokens,
old_embedding_dim,
device=old_embeddings.weight.device,
dtype=old_embeddings.weight.dtype,
)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
# numbers of tokens to copy
n = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
# Build new embeddings
new_embeddings = nn.Embedding(
new_num_tokens,
old_embedding_dim,
device=old_embeddings.weight.device,
dtype=old_embeddings.weight.dtype,
)
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
else:
# Build new embeddings
new_embeddings = nn.Embedding(
new_num_tokens,
old_embedding_dim,
device=old_embeddings.weight.device,
dtype=old_embeddings.weight.dtype,
)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
return new_embeddings
@ -1636,7 +1630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
)
if old_num_tokens == new_num_tokens:
if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
return old_lm_head
if not isinstance(old_lm_head, nn.Linear):
@ -1650,51 +1644,50 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
has_new_lm_head_bias = old_lm_head.bias is not None
# When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
# because the shape of the new embedding layer is used across various modeling files
# as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
# to errors when training.
new_lm_head = nn.Linear(
*new_lm_head_shape,
bias=has_new_lm_head_bias,
device=old_lm_head.weight.device,
dtype=old_lm_head.weight.dtype,
)
# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
# XXX: put the long block of code in a wrapper
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
new_lm_head = nn.Linear(
*new_lm_head_shape,
bias=has_new_lm_head_bias,
device=old_lm_head.weight.device,
dtype=old_lm_head.weight.dtype,
)
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
self._init_weights(new_lm_head)
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
else:
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
self._copy_lm_head_original_to_resized(
new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
)
else:
new_lm_head = nn.Linear(
*new_lm_head_shape,
bias=has_new_lm_head_bias,
device=old_lm_head.weight.device,
dtype=old_lm_head.weight.dtype,
self._copy_lm_head_original_to_resized(
new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
)
self._init_weights(new_lm_head)
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
else:
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
return new_lm_head
def _copy_lm_head_original_to_resized(
self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
):
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
else:
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
def resize_position_embeddings(self, new_num_position_embeddings: int):
raise NotImplementedError(
f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "