From fabe17a726bbf6081cfbcc975d8ac451a81f3e2d Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 31 May 2023 15:32:21 -0400
Subject: [PATCH] Skip device placement for past key values in decoder models
 (#23919)

---
 src/transformers/modeling_utils.py | 6 +++++-
 src/transformers/models/bart/modeling_bart.py | 1 +
 .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 1 +
 src/transformers/models/blip_2/modeling_blip_2.py | 1 +
 src/transformers/models/bloom/modeling_bloom.py | 1 +
 src/transformers/models/bridgetower/modeling_bridgetower.py | 1 +
 src/transformers/models/codegen/modeling_codegen.py | 1 +
 src/transformers/models/gpt2/modeling_gpt2.py | 1 +
 src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 1 +
 src/transformers/models/gpt_neo/modeling_gpt_neo.py | 1 +
 src/transformers/models/gpt_neox/modeling_gpt_neox.py | 1 +
 .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 1 +
 src/transformers/models/gptj/modeling_gptj.py | 1 +
 .../models/gptsan_japanese/modeling_gptsan_japanese.py | 1 +
 src/transformers/models/llama/modeling_llama.py | 1 +
 15 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index aa982873664..1c4ce11b523 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None
 
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
@@ -2887,7 +2888,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)
 
         if output_loading_info:
             if loading_info is None:
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index c1da4eb288e..59d99f9aa13 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index e4c64e12b55..61972c53f01 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 381a3fe0ed7..2cb77f44a28 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]
 
     def _init_weights(self, module):
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 5c0d570cbe9..a3d2242f8a6 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index 29bd5b581d9..eb95f12c854 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index 82a4f95dde1..aff556d2f14 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 5fc5451141a..00d92f0bb23 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index f6ec24d7773..0bbe7648237 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 02f7d5534bd..f98a73373b4 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 48809359463..5ae2807608a 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index d18554480bd..aeab9434e51 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 82cb280caba..91b32d1a824 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
index 4343340a7f7..a0b68543b3d 100644
--- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
+++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     @property
     def dummy_inputs(self):
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 80cfdfa5f06..346da82d86f 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
 
    def _init_weights(self, module):
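
Note on the modeling_utils.py hunk (illustration only, not part of the patch): the new `skip_keys` argument of accelerate's `dispatch_model` is not available in older accelerate releases, so the patch introspects the installed function's signature and only forwards the argument when it is actually declared. Below is a minimal, self-contained sketch of that signature-introspection idiom, using hypothetical stand-ins `dispatch_model_old` and `dispatch_model_new` in place of the real accelerate function.

    import inspect

    def dispatch_model_old(model, device_map=None):
        # Stand-in for an accelerate release whose dispatch_model has no `skip_keys`.
        return "dispatched (hooks move every input to the execution device)"

    def dispatch_model_new(model, device_map=None, skip_keys=None):
        # Stand-in for a release whose dispatch_model accepts `skip_keys`.
        return f"dispatched (device placement skipped for {skip_keys!r})"

    def dispatch(dispatch_fn, model, device_map, skip_keys):
        # Same idiom as the patch: build the kwargs every version supports,
        # then add `skip_keys` only if the target signature declares it.
        kwargs = {"device_map": device_map}
        if "skip_keys" in inspect.signature(dispatch_fn).parameters:
            kwargs["skip_keys"] = skip_keys
        return dispatch_fn(model, **kwargs)

    print(dispatch(dispatch_model_old, object(), {"": 0}, "past_key_values"))
    print(dispatch(dispatch_model_new, object(), {"": 0}, "past_key_values"))

When the argument does get through, the per-model `_skip_keys_device_placement = "past_key_values"` attributes tell accelerate's dispatch hooks to leave the cached `past_key_values` tensors where they are instead of re-placing them on every forward pass, which is the behavior the commit title describes.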