Skip device placement for past key values in decoder models (#23919)

Sylvain Gugger authored on 2023-05-31 15:32:21 -04:00 · committed by GitHub
parent 6affd9cd7c
commit fabe17a726
15 changed files with 19 additions and 1 deletion

View File

@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
@@ -2887,7 +2888,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)
         if output_loading_info:
             if loading_info is None:
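A minimal standalone sketch of the guard above, assuming only that `accelerate`'s `dispatch_model` is importable; the helper name `dispatch_with_optional_skip_keys` is hypothetical and only illustrates why the signature check keeps older `accelerate` releases (without a `skip_keys` parameter) working:

```python
import inspect

from accelerate import dispatch_model


def dispatch_with_optional_skip_keys(model, device_map, skip_keys=None, **kwargs):
    # Only forward `skip_keys` when the installed accelerate version
    # actually accepts it; otherwise fall back to the old call.
    if skip_keys is not None and "skip_keys" in inspect.signature(dispatch_model).parameters:
        kwargs["skip_keys"] = skip_keys
    return dispatch_model(model, device_map=device_map, **kwargs)
```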

View File

@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     def _init_weights(self, module):
         std = self.config.init_std
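For context, a hedged usage sketch (the checkpoint name is only an example, and a multi-device setup is assumed): once a model is dispatched with a `device_map`, the hooks added by `dispatch_model` normally move every input to the layer's device; with `_skip_keys_device_placement = "past_key_values"`, the cached key/value tensors are left where they are between generation steps.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# "facebook/bart-large-cnn" is just an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn", device_map="auto")

inputs = tokenizer("Summarize: the cache stays put.", return_tensors="pt").to(model.device)
# During generate(), the past key/value cache is no longer re-dispatched on
# every step, because BartPretrainedModel declares
# _skip_keys_device_placement = "past_key_values".
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```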

View File

@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     def _init_weights(self, module):
         std = self.config.init_std

View File

@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]
     def _init_weights(self, module):

View File

@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

View File

@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"
     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):

View File

@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

View File

@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

View File

@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

View File

@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

View File

@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"
     def _init_weights(self, module):
         """Initialize the weights"""

View File

@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"
     def _init_weights(self, module):
         """Initialize the weights"""

View File

@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

View File

@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"
     @property
     def dummy_inputs(self):

View File

@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
     def _init_weights(self, module):
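To make the skipped object concrete, a small decoder-only sketch (gpt2 used as an example checkpoint, run on a single device for simplicity): `past_key_values` is the cache returned by one forward pass and fed back into the next, and it is exactly this input that the dispatch hooks now leave in place on multi-device setups.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The cache that gets skipped is", return_tensors="pt")
with torch.no_grad():
    # First pass returns the key/value cache alongside the logits.
    out = model(**inputs, use_cache=True)
    next_token = out.logits[:, -1:].argmax(dim=-1)
    # Second pass reuses the cache: this `past_key_values` argument is the
    # input that _skip_keys_device_placement exempts from device placement.
    out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
print(out.logits.shape)
```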