added support for gradient checkpointing in ESM models (#26386)

sanjeevk-os 2023-09-26 18:15:53 +10:00 committed by GitHub
parent a8531f3bfd
commit 6ce6a5adb9
2 changed files with 27 additions and 6 deletions

src/transformers/models/esm/modeling_esm.py

@@ -690,6 +690,7 @@ class EsmPreTrainedModel(PreTrainedModel):
     config_class = EsmConfig
     base_model_prefix = "esm"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -709,6 +710,10 @@ class EsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, EsmEncoder):
+            module.gradient_checkpointing = value
+

 ESM_START_DOCSTRING = r"""
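Moving `supports_gradient_checkpointing = True` and this hook onto the shared `EsmPreTrainedModel` base class means every ESM head model can toggle checkpointing through the standard `gradient_checkpointing_enable()` / `gradient_checkpointing_disable()` API. A minimal sketch of what that looks like from the caller's side; the checkpoint name `facebook/esm2_t6_8M_UR50D` is only an illustrative public ESM-2 model, not something this commit touches:

    from transformers import EsmForMaskedLM

    model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

    model.gradient_checkpointing_enable()            # routed through the _set_gradient_checkpointing hook above
    assert model.esm.encoder.gradient_checkpointing  # EsmEncoder will now recompute activations in backward

    model.gradient_checkpointing_disable()
    assert not model.esm.encoder.gradient_checkpointing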
@@ -785,8 +790,6 @@ class EsmModel(EsmPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """

-    supports_gradient_checkpointing = False
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -803,10 +806,6 @@
         # Initialize weights and apply final processing
         self.post_init()

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, EsmEncoder):
-            module.gradient_checkpointing = value
-
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
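The flag that the hook flips is consumed inside `EsmEncoder.forward`, which already wraps each layer call in `torch.utils.checkpoint` when `gradient_checkpointing` is set and the model is training; that encoder code predates this commit and is unchanged here. The toy module below is a self-contained sketch of that pattern under those assumptions, not the actual ESM source:

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class ToyEncoder(nn.Module):
        # Illustrative stand-in for EsmEncoder: recompute layer activations when checkpointing is on.
        def __init__(self, hidden_size=16, num_layers=4):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
            self.gradient_checkpointing = False  # what a _set_gradient_checkpointing-style hook toggles

        def forward(self, hidden_states):
            for layer in self.layers:
                if self.gradient_checkpointing and self.training:
                    # Do not store this layer's intermediate activations; recompute them during backward.
                    hidden_states = checkpoint(layer, hidden_states)
                else:
                    hidden_states = layer(hidden_states)
            return hidden_states

    encoder = ToyEncoder().train()
    encoder.gradient_checkpointing = True
    out = encoder(torch.randn(2, 16, requires_grad=True))
    out.sum().backward()  # gradients flow through the recomputed segments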

tests/models/esm/test_modeling_esm.py

@@ -151,6 +151,24 @@ class EsmModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

+    def create_and_check_forward_and_backwards(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        gradient_checkpointing=False,
+    ):
+        model = EsmForMaskedLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+        model.to(torch_device)
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -219,6 +237,10 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

+    def test_esm_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
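
Taken together, the new tester method and the `test_esm_gradient_checkpointing` case exercise the same path a user hits when fine-tuning with checkpointing enabled. A rough usage sketch mirroring `create_and_check_forward_and_backwards`; the checkpoint name and the protein sequence are illustrative placeholders, not part of the commit:

    from transformers import AutoTokenizer, EsmForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
    model.gradient_checkpointing_enable()  # before this commit ESM did not advertise support, so this call would fail
    model.train()

    inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()  # activations in each EsmLayer are recomputed here instead of being stored in the forward pass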