Skip device placement for past key values in decoder models (#23919)
commit fabe17a726 (parent 6affd9cd7c)
@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None

     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
@@ -2887,7 +2888,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)

         if output_loading_info:
             if loading_info is None:
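The hunk above forwards skip_keys to accelerate's dispatch_model only when the installed accelerate version accepts it, using signature introspection so older versions keep working. Below is a minimal, self-contained sketch of that feature-detection pattern, assuming nothing beyond the standard library; dispatch_v1, dispatch_v2, and call_dispatch are hypothetical names used only for illustration, not part of transformers or accelerate.

import inspect


def dispatch_v1(model, device_map=None):
    # Stand-in for an older API without a skip_keys parameter.
    return f"dispatched {model} with {device_map}"


def dispatch_v2(model, device_map=None, skip_keys=None):
    # Stand-in for a newer API that accepts skip_keys.
    return f"dispatched {model} with {device_map}, skipping {skip_keys}"


def call_dispatch(dispatch_fn, model, device_map, skip_keys):
    # Only pass skip_keys when the target function actually declares it.
    kwargs = {"device_map": device_map}
    if "skip_keys" in inspect.signature(dispatch_fn).parameters:
        kwargs["skip_keys"] = skip_keys
    return dispatch_fn(model, **kwargs)


print(call_dispatch(dispatch_v1, "m", "auto", "past_key_values"))  # old signature: skip_keys silently dropped
print(call_dispatch(dispatch_v2, "m", "auto", "past_key_values"))  # new signature: skip_keys forwarded

This keeps the new behavior opt-in at runtime rather than raising a TypeError against an older dependency.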
@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]

     def _init_weights(self, module):
@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
    _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"

     @property
     def dummy_inputs(self):
@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

     def _init_weights(self, module):
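For context on what the per-model hunks enable: once a decoder class sets _skip_keys_device_placement = "past_key_values" and accelerate's dispatch_model accepts skip_keys, the device-placement hooks leave the cached key/value tensors on the device that produced them instead of moving them on every generation step. A hedged usage sketch follows; it assumes a transformers build containing this commit, accelerate installed, and network access to download the public gpt2 checkpoint used here purely as an example.

from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT2PreTrainedModel sets _skip_keys_device_placement = "past_key_values" in this commit,
# so with device_map="auto" the dispatch hooks skip the KV cache during generation.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))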