Add ESM contact prediction (#20535)

* Draft addition of new head

* Finish adding contact heads + tests for ESM

* Add TF contact prediction head

* make fixup

* Minor fix to convert_esm.py

* Clean up function names and comments
Author: Matt
Date: 2022-12-02 14:03:30 +00:00 (committed by GitHub)
Commit: c54646b13d
Parent: cc3d0e1b01
3 changed files with 168 additions and 1 deletion

src/transformers/models/esm/convert_esm.py

@@ -284,6 +284,9 @@ def convert_esm_checkpoint_to_pytorch(
    model.lm_head.decoder.weight = esm.lm_head.weight
    model.lm_head.bias = esm.lm_head.bias

    # Contact prediction head
    transfer_and_check_weights(esm.contact_head, model.esm.contact_head)

    # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
    if is_folding_model:
        # Folding models aren't trained on masked inputs and don't like mask tokens.
@@ -351,6 +354,20 @@ def convert_esm_checkpoint_to_pytorch(
    if not success:
        raise Exception("Something went wRoNg")

    if not is_folding_model:
        # Let's check contact prediction too
        our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"])
        their_output = esm.predict_contacts(hf_tokens["input_ids"])
        max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
        success = torch.allclose(our_output, their_output, atol=1e-5)

        print("Contact prediction testing:")
        print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-5
        print("Do both models output the same tensors?", "🔥" if success else "💩")

        if not success:
            raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
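
Once the dump has been written, a quick sanity check (not part of this diff) is to reload it and run contact prediction through the new head. The checkpoint directory below is a placeholder and the tokenizer name is only an example:

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # example tokenizer
model = EsmForMaskedLM.from_pretrained("/path/to/pytorch_dump_folder")  # hypothetical dump path

inputs = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT", return_tensors="pt")
with torch.no_grad():
    contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
print(contacts.shape)  # (batch, seq_len - 2, seq_len - 2): cls/eos positions are trimmed by the head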

src/transformers/models/esm/modeling_esm.py

@@ -68,6 +68,23 @@ def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def average_product_correct(x):
    "Perform average product correct, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized
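
Average product correction subtracts the normalized outer product of the row and column sums, which suppresses background signal from positions that attend (or are attended to) heavily everywhere. A minimal sketch with toy numbers, doing the same thing as the two helpers above:

import torch

x = torch.rand(1, 5, 5)                # toy (batch, seq, seq) attention map
sym = x + x.transpose(-1, -2)          # symmetrize
a1 = sym.sum(-1, keepdim=True)         # row sums
a2 = sym.sum(-2, keepdim=True)         # column sums
apc = sym - a1 * a2 / sym.sum((-1, -2), keepdim=True)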
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
@@ -111,6 +128,41 @@ class RotaryEmbedding(torch.nn.Module):
        )


class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # remove eos token attentions
        eos_mask = tokens.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
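
Purely as a shape sketch (random tensors, illustrative layer/head counts), the head maps stacked per-layer attention maps to per-pair contact probabilities, trimming the cls and eos positions:

import torch

head = EsmContactPredictionHead(in_features=6 * 20, bias=True)
tokens = torch.randint(4, 20, (2, 10))       # fake token ids that avoid the eos index
attentions = torch.rand(2, 6, 20, 10, 10)    # (batch, layers, heads, seq, seq)
contacts = head(tokens, attentions)
print(contacts.shape)                        # torch.Size([2, 8, 8])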
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -736,7 +788,6 @@ class EsmModel(EsmPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    supports_gradient_checkpointing = False
    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Esm
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
@@ -746,6 +797,10 @@ class EsmModel(EsmPreTrainedModel):
        self.pooler = EsmPooler(config) if add_pooling_layer else None

        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        # Initialize weights and apply final processing
        self.post_init()
@@ -894,6 +949,17 @@ class EsmModel(EsmPreTrainedModel):
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(tokens, attns)
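
To make the mask broadcasting above concrete: the two in-place multiplies zero attention to and from padding positions. A standalone sketch with the same shapes (toy sizes, random values, not the model's actual tensors):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])            # (batch, seq), last position is padding
attns = torch.rand(1, 6, 20, 4, 4)                       # (batch, layers, heads, seq, seq)
attns = attns * attention_mask[:, None, None, None, :]   # zero columns: attention *to* padding
attns = attns * attention_mask[:, None, None, :, None]   # zero rows: attention *from* padding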
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
@@ -983,6 +1049,9 @@ class EsmForMaskedLM(EsmPreTrainedModel):
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)


class EsmLMHead(nn.Module):
    """ESM Head for masked language modeling."""

src/transformers/models/esm/modeling_tf_esm.py

@@ -71,6 +71,23 @@ def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + tf.linalg.matrix_transpose(x)  # Transposes last two dimensions only


def average_product_correct(x):
    "Perform average product correct, used for contact prediction."
    a1 = tf.reduce_sum(x, -1, keepdims=True)
    a2 = tf.reduce_sum(x, -2, keepdims=True)
    a12 = tf.reduce_sum(x, (-1, -2), keepdims=True)

    avg = a1 * a2
    avg = avg / a12
    normalized = x - avg
    return normalized


class TFRotaryEmbedding(Layer):
    """
    Rotary position embeddings based on those in
@@ -115,6 +132,43 @@ class TFRotaryEmbedding(Layer):
        )


class TFEsmContactPredictionHead(Layer):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
        name=None,
    ):
        super().__init__(name=name)
        self.eos_idx = eos_idx
        self.in_features = in_features
        self.regression = Dense(1, use_bias=bias, activation="sigmoid", name="regression")

    def build(self, input_shape):
        super().build(input_shape)
        with tf.name_scope("regression"):
            self.regression.build((None, self.in_features))

    def call(self, tokens, attentions):
        # remove eos token attentions
        eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)
        eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = shape_list(attentions)
        attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = average_product_correct(symmetrize(attentions))
        attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
        return tf.squeeze(self.regression(attentions), 3)
class TFEsmEmbeddings(Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -742,6 +796,15 @@ class TFEsmMainLayer(Layer):
        self.encoder = TFEsmEncoder(config, name="encoder")
        self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None

        self.contact_head = TFEsmContactPredictionHead(
            in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
        )

    def build(self, input_shape):
        super().build(input_shape)
        with tf.name_scope("contact_head"):
            self.contact_head.build(input_shape)

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings
@@ -906,6 +969,18 @@ class TFEsmMainLayer(Layer):
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = tf.stack(attns, axis=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attention_mask = tf.cast(attention_mask, attns.dtype)
        attns *= attention_mask[:, None, None, None]
        attns *= attention_mask[:, None, None, :, None]
        return self.contact_head(tokens, attns)


@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1011,6 +1086,9 @@ class TFEsmModel(TFEsmPreTrainedModel):
            cross_attentions=cross_attns,
        )

    def predict_contacts(self, tokens, attention_mask):
        return self.esm.predict_contacts(tokens, attention_mask)
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
@@ -1123,6 +1201,9 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
        return self.serving_output(output)

    def predict_contacts(self, tokens, attention_mask):
        return self.esm.predict_contacts(tokens, attention_mask)
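
On the TF side the call pattern mirrors PyTorch. A hedged usage sketch, not part of this diff (the checkpoint name is just an example, and from_pt=True may be needed if only PyTorch weights are published for it):

from transformers import AutoTokenizer, TFEsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # example checkpoint
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")   # add from_pt=True if no TF weights exist

inputs = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT", return_tensors="tf")
contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
print(contacts.shape)  # (batch, seq_len - 2, seq_len - 2)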
class TFEsmLMHead(Layer):
    """ESM Head for masked language modeling."""