Add ESM contact prediction (#20535)
* Draft addition of new head
* Finish adding contact heads + tests for ESM
* Add TF contact prediction head
* make fixup
* Minor fix to convert_esm.py
* Clean up function names and comments
This commit is contained in:
parent cc3d0e1b01
commit c54646b13d
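For context, a minimal usage sketch of the new API on the PyTorch side (not part of the diff; the checkpoint name and example sequence are assumptions for illustration, and predict_contacts is the method added in the changes below):

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

# Hypothetical checkpoint; any converted ESM-2 checkpoint with a contact head should work.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

inputs = tokenizer(["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"], return_tensors="pt")
with torch.no_grad():
    # Per-residue contact probabilities; the cls and eos positions are stripped by the head.
    contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
print(contacts.shape)  # (batch, seq_len - 2, seq_len - 2)
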
@@ -284,6 +284,9 @@ def convert_esm_checkpoint_to_pytorch(
    model.lm_head.decoder.weight = esm.lm_head.weight
    model.lm_head.bias = esm.lm_head.bias

    # Contact prediction head
    transfer_and_check_weights(esm.contact_head, model.esm.contact_head)

    # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
    if is_folding_model:
        # Folding models aren't trained on masked inputs and don't like mask tokens.

@@ -351,6 +354,20 @@ def convert_esm_checkpoint_to_pytorch(
    if not success:
        raise Exception("Something went wRoNg")

    if not is_folding_model:
        # Let's check contact prediction too
        our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"])
        their_output = esm.predict_contacts(hf_tokens["input_ids"])
        max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
        success = torch.allclose(our_output, their_output, atol=1e-5)

        print("Contact prediction testing:")
        print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-5
        print("Do both models output the same tensors?", "🔥" if success else "💩")

        if not success:
            raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

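In this check, esm is the original fair-esm model and hf_tokens is the HF tokenizer output for the same sequences. A rough sketch of how those objects could be produced (an assumption for illustration only; the real conversion script loads whichever checkpoint is selected on the command line):

import esm as esm_package  # the fair-esm package
from transformers import EsmTokenizer

# Hypothetical checkpoint choice, purely for illustration.
esm, alphabet = esm_package.pretrained.esm2_t6_8M_UR50D()
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
hf_tokens = tokenizer(
    ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"],
    return_tensors="pt",
    padding=True,
)
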
@@ -68,6 +68,23 @@ def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def average_product_correct(x):
    "Perform average product correct, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized


class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in

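A quick sanity check of the two helpers above (illustrative only, not part of the diff): symmetrizing and then applying the average product correction yields a map that is still symmetric in its last two dimensions, with the row/column background signal subtracted out.

import torch

x = torch.rand(1, 2, 8, 8)  # batch x channels x tokens x tokens
y = average_product_correct(symmetrize(x))
# Still symmetric after the correction (up to floating point error).
assert torch.allclose(y, y.transpose(-1, -2), atol=1e-5)
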
@@ -111,6 +128,41 @@ class RotaryEmbedding(torch.nn.Module):
        )


class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # remove eos token attentions
        eos_mask = tokens.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

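An illustrative shape walk-through for the head above (the 6-layer / 20-head configuration is an assumption): the stacked per-layer, per-head attentions are reduced to a single contact probability per residue pair, with the cls and eos positions removed.

import torch

head = EsmContactPredictionHead(in_features=6 * 20, bias=True)  # layers * heads
tokens = torch.randint(4, 30, (2, 10))     # batch of 2, 10 tokens including cls/eos
attentions = torch.rand(2, 6, 20, 10, 10)  # batch x layers x heads x seq x seq
contacts = head(tokens, attentions)
print(contacts.shape)  # torch.Size([2, 8, 8]) -- cls and eos stripped
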
@@ -736,7 +788,6 @@ class EsmModel(EsmPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = False

    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Esm
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

@@ -746,6 +797,10 @@ class EsmModel(EsmPreTrainedModel):

        self.pooler = EsmPooler(config) if add_pooling_layer else None

        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        # Initialize weights and apply final processing
        self.post_init()

@@ -894,6 +949,17 @@ class EsmModel(EsmPreTrainedModel):
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(tokens, attns)


@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):

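The two in-place multiplications above broadcast the 2D padding mask over the stacked 5D attention tensor. A small illustration of the same broadcasting (shapes chosen arbitrarily):

import torch

attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])   # one padded position
attns = torch.ones(1, 2, 3, 4, 4)                       # batch x layers x heads x seq x seq
attns = attns * attention_mask[:, None, None, None, :]  # zero attention *to* the padded token
attns = attns * attention_mask[:, None, None, :, None]  # zero attention *from* the padded token
print(attns[0, 0, 0])  # last row and last column are now zero
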
@@ -983,6 +1049,9 @@ class EsmForMaskedLM(EsmPreTrainedModel):
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)


class EsmLMHead(nn.Module):
    """ESM Head for masked language modeling."""

@@ -71,6 +71,23 @@ def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + tf.linalg.matrix_transpose(x)  # Transposes last two dimensions only


def average_product_correct(x):
    "Perform average product correct, used for contact prediction."
    a1 = tf.reduce_sum(x, -1, keepdims=True)
    a2 = tf.reduce_sum(x, -2, keepdims=True)
    a12 = tf.reduce_sum(x, (-1, -2), keepdims=True)

    avg = a1 * a2
    avg = avg / a12
    normalized = x - avg
    return normalized


class TFRotaryEmbedding(Layer):
    """
    Rotary position embeddings based on those in

@@ -115,6 +132,43 @@ class TFRotaryEmbedding(Layer):
        )


class TFEsmContactPredictionHead(Layer):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
        name=None,
    ):
        super().__init__(name=name)
        self.eos_idx = eos_idx
        self.in_features = in_features
        self.regression = Dense(1, use_bias=bias, activation="sigmoid", name="regression")

    def build(self, input_shape):
        super().build(input_shape)
        with tf.name_scope("regression"):
            self.regression.build((None, self.in_features))

    def call(self, tokens, attentions):
        # remove eos token attentions
        eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)
        eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = shape_list(attentions)
        attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = average_product_correct(symmetrize(attentions))
        attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
        return tf.squeeze(self.regression(attentions), 3)


class TFEsmEmbeddings(Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

@@ -742,6 +796,15 @@ class TFEsmMainLayer(Layer):
        self.encoder = TFEsmEncoder(config, name="encoder")
        self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None

        self.contact_head = TFEsmContactPredictionHead(
            in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
        )

    def build(self, input_shape):
        super().build(input_shape)
        with tf.name_scope("contact_head"):
            self.contact_head.build(input_shape)

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

@@ -906,6 +969,18 @@ class TFEsmMainLayer(Layer):
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = tf.stack(attns, axis=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attention_mask = tf.cast(attention_mask, attns.dtype)
        attns *= attention_mask[:, None, None, None]
        attns *= attention_mask[:, None, None, :, None]
        return self.contact_head(tokens, attns)


@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",

@@ -1011,6 +1086,9 @@ class TFEsmModel(TFEsmPreTrainedModel):
            cross_attentions=cross_attns,
        )

    def predict_contacts(self, tokens, attention_mask):
        return self.esm.predict_contacts(tokens, attention_mask)


@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):

@@ -1123,6 +1201,9 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):

        return self.serving_output(output)

    def predict_contacts(self, tokens, attention_mask):
        return self.esm.predict_contacts(tokens, attention_mask)


class TFEsmLMHead(Layer):
    """ESM Head for masked language modeling."""

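And a mirror-image usage sketch on the TensorFlow side (again not part of the diff; the checkpoint name is an assumption, and from_pt=True may be needed if no TF weights have been published for it):

from transformers import AutoTokenizer, TFEsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")  # add from_pt=True if needed

inputs = tokenizer(["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"], return_tensors="tf")
contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
print(contacts.shape)  # (batch, seq_len - 2, seq_len - 2)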