Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-25 23:38:59 +06:00
added support for gradient checkpointing in ESM models (#26386)
parent a8531f3bfd
commit 6ce6a5adb9
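In practical terms, this change lets gradient checkpointing be toggled on ESM models through the standard `gradient_checkpointing_enable()` API, exactly as the new test below does. A minimal usage sketch mirroring that test (the tiny `EsmConfig` hyperparameters here are illustrative and not taken from this commit):

import torch
from transformers import EsmConfig, EsmForMaskedLM

# Small illustrative config; real ESM checkpoints are far larger.
config = EsmConfig(
    vocab_size=33,
    pad_token_id=1,
    mask_token_id=32,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = EsmForMaskedLM(config)
model.gradient_checkpointing_enable()  # trade extra compute for lower activation memory
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = model(input_ids, labels=input_ids)
outputs.loss.backward()  # backward recomputes the checkpointed activations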
@@ -690,6 +690,7 @@ class EsmPreTrainedModel(PreTrainedModel):
     config_class = EsmConfig
     base_model_prefix = "esm"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -709,6 +710,10 @@ class EsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, EsmEncoder):
+            module.gradient_checkpointing = value
+
 
 ESM_START_DOCSTRING = r"""
 
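The hook above only flips a boolean on `EsmEncoder`; the memory saving comes from the encoder wrapping each layer call in `torch.utils.checkpoint` whenever that flag is set during training. A simplified, self-contained sketch of that pattern (`TinyEncoder` is illustrative, not the actual `EsmEncoder`):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class TinyEncoder(nn.Module):
    """Stand-in for an encoder stack; not the real EsmEncoder."""

    def __init__(self, hidden_size=16, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = False  # what _set_gradient_checkpointing(value=True) flips

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Do not store this layer's activations; recompute them during backward.
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states

encoder = TinyEncoder()
encoder.gradient_checkpointing = True
encoder.train()
encoder(torch.randn(2, 16)).sum().backward()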
@@ -785,8 +790,6 @@ class EsmModel(EsmPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    supports_gradient_checkpointing = False
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -803,10 +806,6 @@ class EsmModel(EsmPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, EsmEncoder):
-            module.gradient_checkpointing = value
-
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
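With the duplicate removed here, the flag and the `_set_gradient_checkpointing` hook live only on `EsmPreTrainedModel`, so every ESM head model inherits them. At the time of this commit, `PreTrainedModel.gradient_checkpointing_enable()` essentially checks `supports_gradient_checkpointing` and then applies the hook to every submodule, roughly as in this simplified stand-in (not the library's exact code):

from functools import partial
from torch import nn

def enable_gradient_checkpointing(model: nn.Module) -> None:
    # Simplified stand-in for PreTrainedModel.gradient_checkpointing_enable():
    # visit every submodule and let the model-specific hook decide which ones
    # (here, EsmEncoder instances) get their gradient_checkpointing flag set.
    if not getattr(model, "supports_gradient_checkpointing", False):
        raise ValueError(f"{type(model).__name__} does not support gradient checkpointing.")
    model.apply(partial(model._set_gradient_checkpointing, value=True))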
@@ -151,6 +151,24 @@ class EsmModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
 
+    def create_and_check_forward_and_backwards(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        gradient_checkpointing=False,
+    ):
+        model = EsmForMaskedLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+        model.to(torch_device)
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -219,6 +237,10 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
 
+    def test_esm_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
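A property the new test relies on implicitly: checkpointing changes where activations come from during the backward pass, not the numbers the model computes. An illustrative sanity check (not part of this commit; dropout is zeroed explicitly so the two forward passes are deterministic):

import torch
from transformers import EsmConfig, EsmForMaskedLM

config = EsmConfig(
    vocab_size=33, pad_token_id=1, mask_token_id=32,
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
    hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0,
)
torch.manual_seed(0)
model = EsmForMaskedLM(config)
input_ids = torch.randint(0, config.vocab_size, (2, 16))

model.eval()
with torch.no_grad():
    baseline = model(input_ids).logits  # plain forward pass

model.gradient_checkpointing_enable()
model.train()  # checkpointing is only active in training mode
recomputed = model(input_ids).logits
torch.testing.assert_close(baseline, recomputed)  # same outputs, fewer stored activations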