added support for gradient checkpointing in ESM models (#26386)

Repository: https://github.com/huggingface/transformers.git
Commit: 6ce6a5adb9
Parent: a8531f3bfd
src/transformers/models/esm/modeling_esm.py
@@ -690,6 +690,7 @@ class EsmPreTrainedModel(PreTrainedModel):
 
     config_class = EsmConfig
     base_model_prefix = "esm"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -709,6 +710,10 @@ class EsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, EsmEncoder):
+            module.gradient_checkpointing = value
+
 
 ESM_START_DOCSTRING = r"""
 
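For context, the hook added above is reached through PreTrainedModel.gradient_checkpointing_enable(). The snippet below is a simplified sketch of that dispatch as it worked in the transformers versions current at the time of this commit, not the library's verbatim code; SketchPreTrainedModel is a hypothetical stand-in for PreTrainedModel (the real method lives in modeling_utils.py).

from functools import partial

import torch.nn as nn


class SketchPreTrainedModel(nn.Module):
    """Hypothetical stand-in for PreTrainedModel, showing the dispatch only."""

    supports_gradient_checkpointing = False  # the class-level gate the first hunk flips to True for ESM

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # nn.Module.apply(...) visits this module and every submodule, so the
        # per-model _set_gradient_checkpointing hook sees EsmEncoder and flips its flag.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def _set_gradient_checkpointing(self, module, value=False):
        # Overridden per model family; for ESM this is the method added in the hunk above.
        pass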
@@ -785,8 +790,6 @@ class EsmModel(EsmPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    supports_gradient_checkpointing = False
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -803,10 +806,6 @@ class EsmModel(EsmPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, EsmEncoder):
-            module.gradient_checkpointing = value
-
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
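With the hook moved to the base class, every ESM head can trade compute for memory through the standard API. A minimal usage sketch, assuming the public facebook/esm2_t6_8M_UR50D checkpoint and a working PyTorch install; the protein sequence and the labels=input_ids choice are only illustrative.

from transformers import AutoTokenizer, EsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

model.gradient_checkpointing_enable()  # raised a ValueError before this commit
model.train()                          # activations are only recomputed in training mode

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])  # toy MLM loss over the full sequence
outputs.loss.backward()                # layer activations are recomputed here instead of stored

print(model.esm.encoder.gradient_checkpointing)  # True: set by _set_gradient_checkpointing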
tests/models/esm/test_modeling_esm.py
@@ -151,6 +151,24 @@ class EsmModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
 
+    def create_and_check_forward_and_backwards(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        gradient_checkpointing=False,
+    ):
+        model = EsmForMaskedLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+        model.to(torch_device)
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
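What this test exercises under the hood: when the flag set by _set_gradient_checkpointing is True and the model is in training mode, the encoder routes each layer through torch.utils.checkpoint so activations are recomputed during the backward pass instead of being stored. The snippet below is a heavily simplified, Bert-style sketch of that control flow, not EsmEncoder's actual forward (which also threads attention masks, head masks, caching, and output collection).

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SketchEncoder(nn.Module):
    """Hypothetical minimal encoder showing only how the flag changes the layer loop."""

    def __init__(self, layers):
        super().__init__()
        self.layer = nn.ModuleList(layers)
        self.gradient_checkpointing = False  # flipped to True via _set_gradient_checkpointing

    def forward(self, hidden_states):
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:
                # store no intermediate activations; recompute this layer on backward
                hidden_states = checkpoint(layer_module, hidden_states)
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states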
@@ -219,6 +237,10 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
 
+    def test_esm_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
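One way to run just the new test locally, assuming a development checkout of transformers with pytest installed (the path matches the repository layout).

import pytest

# -k filters to the gradient-checkpointing test added in the last hunk; -q keeps output short
pytest.main(["tests/models/esm/test_modeling_esm.py", "-k", "test_esm_gradient_checkpointing", "-q"])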