From 6ce6a5adb9ffc071ffe97ee7a7736a120af3d22c Mon Sep 17 00:00:00 2001
From: sanjeevk-os <73068589+sanjeevk-os@users.noreply.github.com>
Date: Tue, 26 Sep 2023 18:15:53 +1000
Subject: [PATCH] added support for gradient checkpointing in ESM models
 (#26386)

---
 src/transformers/models/esm/modeling_esm.py | 11 +++++------
 tests/models/esm/test_modeling_esm.py       | 22 +++++++++++++++++++++
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index 05693b0c1e1..ac3e1d77ecf 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -690,6 +690,7 @@ class EsmPreTrainedModel(PreTrainedModel):
 
     config_class = EsmConfig
     base_model_prefix = "esm"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -709,6 +710,10 @@ class EsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, EsmEncoder):
+            module.gradient_checkpointing = value
+
 
 ESM_START_DOCSTRING = r"""
 
@@ -785,8 +790,6 @@ class EsmModel(EsmPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    supports_gradient_checkpointing = False
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -803,10 +806,6 @@ class EsmModel(EsmPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, EsmEncoder):
-            module.gradient_checkpointing = value
-
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py
index 8af7a318ac6..d09326df606 100644
--- a/tests/models/esm/test_modeling_esm.py
+++ b/tests/models/esm/test_modeling_esm.py
@@ -151,6 +151,24 @@ class EsmModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
 
+    def create_and_check_forward_and_backwards(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        gradient_checkpointing=False,
+    ):
+        model = EsmForMaskedLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+        model.to(torch_device)
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -219,6 +237,10 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
 
+    def test_esm_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
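
Note (not part of the patch): a minimal usage sketch of how gradient checkpointing would be enabled on an ESM model once this change is in, going through the standard PreTrainedModel.gradient_checkpointing_enable() entry point that the new test exercises. The checkpoint name facebook/esm2_t6_8M_UR50D is assumed here for illustration only.

# Illustrative sketch, not part of the diff above; assumes the facebook/esm2_t6_8M_UR50D checkpoint.
from transformers import AutoTokenizer, EsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.gradient_checkpointing_enable()  # recompute encoder activations in the backward pass to save memory
model.train()  # checkpointing only matters when gradients are being computed

# Masked-language-modeling loss over a short protein sequence; labels reuse the input ids.
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()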