Add ESMFold (#19977)

* initial commit

* First draft that gets outputs without crashing!

* Add all the ported openfold dependencies

* testing

* Restructure config files for ESMFold

* Debugging to find output discrepancies

* Mainly style

* Make model runnable without extra deps

* Remove utils and merge them to the modeling file

* Use correct gelu and remove some debug prints

* More cleanup

* Update esm docs

* Update conversion script to support ESMFold properly

* Port some top-level changes from ESMFold repo

* Expand EsmFold docstrings

* Make attention_mask optional (default to all 1s)

* Add inference test for ESMFold

* Use config and not n kwargs

* Add modeling output class

* Remove einops

* Remove chunking in ESM FFN

* Update tests for ESMFold

* Quality

* Repo consistency

* Remove tree dependency from ESMFold

* make fixup

* Add an error in case my structure map function breaks later

* Remove needless code

* Stop auto-casting the LM to float16 so CPU tests pass

* Stop auto-casting the LM to float16 so CPU tests pass

* Final test updates

* Split test file

* Copyright and quality

* Unpin PyTorch to see built doc

* Fix config file to_dict() method

* Add some docstrings to the output

* Skip TF checkpoint tests for ESM until we reupload those

* make fixup

* More docstrings

* Unpin to get even with main

* Flag example to write

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
Matt 2022-11-01 01:32:58 +00:00 committed by GitHub
parent 4c9e0f029e
commit 7f9b7b3f0e
22 changed files with 6820 additions and 89 deletions


@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License.
## Overview
This page provides code and pre-trained weights for Transformer protein language models from Meta AI's Fundamental
AI Research Team, providing the state-of-the-art ESM-2, and the previously released ESM-1b and ESM-1v. Transformer
protein language models were introduced in the paper [Biological structure and function emerge from scaling
AI Research Team, providing the state-of-the-art ESMFold and ESM-2, and the previously released ESM-1b and ESM-1v.
Transformer protein language models were introduced in the paper [Biological structure and function emerge from scaling
unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by
Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott,
C. Lawrence Zitnick, Jerry Ma, and Rob Fergus.
@ -27,6 +27,13 @@ It was released with the paper [Language models of protein sequences at the scal
structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie,
Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido and Alexander Rives.
Also introduced in this paper was ESMFold. It uses an ESM-2 stem with a head that can predict folded protein
structures with state-of-the-art accuracy. Unlike [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2),
it relies on the token embeddings from the large pre-trained protein language model stem and does not perform a multiple
sequence alignment (MSA) step at inference time, which means that ESMFold checkpoints are fully "standalone" -
they do not require a database of known protein sequences and structures with associated external query tools
to make predictions, and are much faster as a result.
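A minimal inference sketch (the hub checkpoint id `facebook/esmfold_v1` and the exact tokenizer call are assumptions here rather than something shown in this diff):

```python
from transformers import AutoTokenizer, EsmForProteinFolding

# Hypothetical checkpoint id for the converted ESMFold weights
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")

# ESMFold takes raw amino-acid strings; no special tokens are added by the caller
inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)
outputs = model(**inputs)
predicted_positions = outputs.positions  # predicted atom positions of the folded structure
```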
The abstract from
"Biological structure and function emerge from scaling unsupervised learning to 250
@ -63,17 +70,22 @@ order of magnitude faster than AlphaFold2, enabling exploration of the structura
proteins in practical timescales.*
Tips:
- ESM models are trained with a masked language modeling (MLM) objective.
The original code can be found [here](https://github.com/facebookresearch/esm) and was
developed by the Fundamental AI Research team at Meta AI.
This model was contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
ESM-1b, ESM-1v and ESM-2 were contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
and [Matt](https://huggingface.co/Rocketknight1).
ESMFold was contributed to huggingface by [Matt](https://huggingface.co/Rocketknight1) and
[Sylvain](https://huggingface.co/sgugger), with a big thank you to Nikita Smetanin, Roshan Rao and Tom Sercu for their
help throughout the process!
The HuggingFace port of ESMFold uses portions of the [openfold](https://github.com/aqlaboratory/openfold) library.
The `openfold` library is licensed under the Apache License 2.0.
## EsmConfig
[[autodoc]] EsmConfig
@ -108,6 +120,11 @@ and [Matt](https://huggingface.co/Rocketknight1).
[[autodoc]] EsmForTokenClassification
- forward
## EsmForProteinFolding
[[autodoc]] EsmForProteinFolding
- forward
## TFEsmModel
[[autodoc]] TFEsmModel


@ -1265,7 +1265,9 @@ else:
_import_structure["models.esm"].extend(
[
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"EsmFoldPreTrainedModel",
"EsmForMaskedLM",
"EsmForProteinFolding",
"EsmForSequenceClassification",
"EsmForTokenClassification",
"EsmModel",
@ -4144,7 +4146,9 @@ if TYPE_CHECKING:
)
from .models.esm import (
ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
EsmFoldPreTrainedModel,
EsmForMaskedLM,
EsmForProteinFolding,
EsmForSequenceClassification,
EsmForTokenClassification,
EsmModel,


@ -39,6 +39,7 @@ else:
"EsmModel",
"EsmPreTrainedModel",
]
_import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
try:
if not is_tf_available():
@ -55,7 +56,6 @@ else:
"TFEsmPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
from .tokenization_esm import EsmTokenizer
@ -74,6 +74,7 @@ if TYPE_CHECKING:
EsmModel,
EsmPreTrainedModel,
)
from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding
try:
if not is_tf_available():


@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2021 Facebook and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,12 +14,16 @@
# limitations under the License.
""" ESM model configuration"""
from dataclasses import asdict, dataclass
from typing import Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
# TODO Update this
ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json",
# See all ESM models at https://huggingface.co/models?filter=esm
@ -118,9 +122,12 @@ class EsmConfig(PretrainedConfig):
classifier_dropout=None,
emb_layer_norm_before=None,
token_dropout=False,
is_folding_model=False,
esmfold_config=None,
vocab_list=None,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
@ -138,5 +145,225 @@ class EsmConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id
self.is_folding_model = is_folding_model
if is_folding_model:
if esmfold_config is None:
logger.info("No esmfold_config supplied for folding model, using default values.")
esmfold_config = EsmFoldConfig()
elif isinstance(esmfold_config, dict):
esmfold_config = EsmFoldConfig(**esmfold_config)
self.esmfold_config = esmfold_config
if vocab_list is None:
logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
self.vocab_list = get_default_vocab_list()
else:
self.vocab_list = vocab_list
else:
self.esmfold_config = None
self.vocab_list = None
if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = super().to_dict()
if isinstance(self.esmfold_config, EsmFoldConfig):
output["esmfold_config"] = self.esmfold_config.to_dict()
return output
@dataclass
class EsmFoldConfig:
esm_type: str = None
fp16_esm: bool = True
use_esm_attn_map: bool = False
esm_ablate_pairwise: bool = False
esm_ablate_sequence: bool = False
esm_input_dropout: float = 0
embed_aa: bool = True
bypass_lm: bool = False
lddt_head_hid_dim: int = 128
trunk: "TrunkConfig" = None
def __post_init__(self):
if self.trunk is None:
self.trunk = TrunkConfig()
elif isinstance(self.trunk, dict):
self.trunk = TrunkConfig(**self.trunk)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = asdict(self)
output["trunk"] = self.trunk.to_dict()
return output
@dataclass
class TrunkConfig:
num_blocks: int = 48
sequence_state_dim: int = 1024
pairwise_state_dim: int = 128
sequence_head_width: int = 32
pairwise_head_width: int = 32
position_bins: int = 32
dropout: float = 0
layer_drop: float = 0
cpu_grad_checkpoint: bool = False
max_recycles: int = 4
chunk_size: Optional[int] = 128
structure_module: "StructureModuleConfig" = None
def __post_init__(self):
if self.structure_module is None:
self.structure_module = StructureModuleConfig()
elif isinstance(self.structure_module, dict):
self.structure_module = StructureModuleConfig(**self.structure_module)
if self.max_recycles <= 0:
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
if self.sequence_state_dim % self.sequence_head_width != 0:
raise ValueError(
"`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
f" {self.sequence_state_dim} and {self.sequence_head_width}."
)
if self.pairwise_state_dim % self.pairwise_head_width != 0:
raise ValueError(
"`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
)
sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
raise ValueError(
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width`, got"
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
)
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
raise ValueError(
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width`, got"
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
)
if self.pairwise_state_dim % 2 != 0:
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
if self.dropout >= 0.4:
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = asdict(self)
output["structure_module"] = self.structure_module.to_dict()
return output
@dataclass
class StructureModuleConfig:
"""
Args:
sequence_dim:
Single representation channel dimension
pairwise_dim:
Pair representation channel dimension
ipa_dim:
IPA hidden channel dimension
resnet_dim:
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
num_heads_ipa:
Number of IPA heads
num_qk_points:
Number of query/key points to generate during IPA
num_v_points:
Number of value points to generate during IPA
dropout_rate:
Dropout rate used throughout the layer
num_blocks:
Number of structure module blocks
num_transition_layers:
Number of layers in the single representation transition (Alg. 23 lines 8-9)
num_resnet_blocks:
Number of blocks in the angle resnet
num_angles:
Number of angles to generate in the angle resnet
trans_scale_factor:
Scale of single representation transition hidden dimension
epsilon:
Small number used in angle resnet normalization
inf:
Large number used for attention masking
"""
sequence_dim: int = 384
pairwise_dim: int = 128
ipa_dim: int = 16
resnet_dim: int = 128
num_heads_ipa: int = 12
num_qk_points: int = 4
num_v_points: int = 8
dropout_rate: float = 0.1
num_blocks: int = 8
num_transition_layers: int = 1
num_resnet_blocks: int = 2
num_angles: int = 7
trans_scale_factor: int = 10
epsilon: float = 1e-8
inf: float = 1e5
def to_dict(self):
return asdict(self)
def get_default_vocab_list():
return (
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
"<null_1>",
"<mask>",
)
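As a sketch of how the new folding-model configuration pieces fit together (constructed from defaults only, not tied to any released checkpoint):

```python
from transformers import EsmConfig

# Nested dicts are coerced into the EsmFoldConfig / TrunkConfig / StructureModuleConfig dataclasses
config = EsmConfig(is_folding_model=True, esmfold_config={"trunk": {"num_blocks": 4}})

print(type(config.esmfold_config).__name__)  # EsmFoldConfig
print(config.esmfold_config.trunk.num_blocks)  # 4
print(len(config.vocab_list))  # 33 tokens from the default ESM-2 vocabulary
# to_dict() serializes the nested dataclasses back into plain dicts
print(config.to_dict()["esmfold_config"]["trunk"]["num_blocks"])  # 4
```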


@ -23,7 +23,8 @@ from tempfile import TemporaryDirectory
import torch
import esm as esm_module
from transformers.models.esm.configuration_esm import EsmConfig
from esm.esmfold.v1.pretrained import esmfold_v1
from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
from transformers.models.esm.modeling_esm import (
EsmForMaskedLM,
EsmForSequenceClassification,
@ -33,6 +34,7 @@ from transformers.models.esm.modeling_esm import (
EsmSelfAttention,
EsmSelfOutput,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
from transformers.models.esm.tokenization_esm import EsmTokenizer
from transformers.utils import logging
@ -60,19 +62,51 @@ MODEL_MAPPING = {
"esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
"esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
"esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
"esmfold_v1": esmfold_v1,
}
def transfer_and_check_weights(original_module, our_module):
status = our_module.load_state_dict(original_module.state_dict())
if status.missing_keys:
raise ValueError(f"Missing keys: {status.missing_keys}")
if status.unexpected_keys:
raise ValueError(f"Unexpected keys: {status.unexpected_keys}")
def convert_esm_checkpoint_to_pytorch(
model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
):
"""
Copy/paste/tweak esm's weights to our BERT structure.
"""
esm, alphabet = MODEL_MAPPING[model]()
if model.startswith("esmfold"):
esm = MODEL_MAPPING[model]()
alphabet = esm.esm.alphabet
else:
esm, alphabet = MODEL_MAPPING[model]()
esm.eval() # disable dropout
esm_sent_encoder = esm
if hasattr(esm, "args"):
if model.startswith("esmfold"):
embed_dim = esm.esm.embed_dim
num_layers = esm.esm.num_layers
num_attention_heads = esm.esm.attention_heads
intermediate_size = 4 * embed_dim
token_dropout = esm.esm.token_dropout
emb_layer_norm_before = False # This code path does not exist in ESM-2
position_embedding_type = "rotary"
is_folding_model = True
esmfold_config = EsmFoldConfig()
for key, val in esm.cfg.items():
if hasattr(esmfold_config, key) and key != "trunk":
setattr(esmfold_config, key, val)
for key, val in esm.cfg.trunk.items():
if hasattr(esmfold_config.trunk, key) and key != "structure_module":
setattr(esmfold_config.trunk, key, val)
for key, val in esm.cfg.trunk.structure_module.items():
if hasattr(esmfold_config.trunk.structure_module, key):
setattr(esmfold_config.trunk.structure_module, key, val)
elif hasattr(esm, "args"):
# Indicates an ESM-1b or ESM-1v model
embed_dim = esm.args.embed_dim
num_layers = esm.args.layers
@ -81,6 +115,8 @@ def convert_esm_checkpoint_to_pytorch(
token_dropout = esm.args.token_dropout
emb_layer_norm_before = True if esm.emb_layer_norm_before else False
position_embedding_type = "absolute"
is_folding_model = False
esmfold_config = None
else:
# Indicates an ESM-2 model
embed_dim = esm.embed_dim
@ -90,9 +126,18 @@ def convert_esm_checkpoint_to_pytorch(
token_dropout = esm.token_dropout
emb_layer_norm_before = False # This code path does not exist in ESM-2
position_embedding_type = "rotary"
is_folding_model = False
esmfold_config = None
vocab_list = tuple(alphabet.all_toks)
if is_folding_model:
original_esm_model = esm.esm
else:
original_esm_model = esm
config = EsmConfig(
vocab_size=esm_sent_encoder.embed_tokens.num_embeddings,
vocab_size=original_esm_model.embed_tokens.num_embeddings,
mask_token_id=alphabet.mask_idx,
hidden_size=embed_dim,
num_hidden_layers=num_layers,
@ -102,36 +147,45 @@ def convert_esm_checkpoint_to_pytorch(
layer_norm_eps=1e-5, # PyTorch default used in fairseq
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
pad_token_id=esm.padding_idx,
pad_token_id=alphabet.padding_idx,
emb_layer_norm_before=emb_layer_norm_before,
token_dropout=token_dropout,
position_embedding_type=position_embedding_type,
is_folding_model=is_folding_model,
esmfold_config=esmfold_config,
vocab_list=vocab_list,
)
if classification_head:
config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
print("Our BERT config:", config)
print("Our ESM config:", config)
model = EsmForSequenceClassification(config) if classification_head else EsmForMaskedLM(config)
if model.startswith("esmfold"):
model_class = EsmForProteinFolding
elif classification_head:
model_class = EsmForSequenceClassification
else:
model_class = EsmForMaskedLM
model = model_class(config)
model.eval()
# Now let's copy all the weights.
# Embeddings
model.esm.embeddings.word_embeddings.weight = esm_sent_encoder.embed_tokens.weight
model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
if position_embedding_type == "absolute":
model.esm.embeddings.position_embeddings.weight = esm_sent_encoder.embed_positions.weight
model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight
if config.emb_layer_norm_before:
model.esm.embeddings.layer_norm.weight = esm_sent_encoder.emb_layer_norm_before.weight
model.esm.embeddings.layer_norm.bias = esm_sent_encoder.emb_layer_norm_before.bias
model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias
model.esm.encoder.emb_layer_norm_after.weight = esm_sent_encoder.emb_layer_norm_after.weight
model.esm.encoder.emb_layer_norm_after.bias = esm_sent_encoder.emb_layer_norm_after.bias
model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias
for i in range(config.num_hidden_layers):
# Encoder: start of layer
layer: EsmLayer = model.esm.encoder.layer[i]
# esm_layer: TransformerSentenceEncoderLayer = esm_sent_encoder.layers[i]
esm_layer = esm_sent_encoder.layers[i]
# esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i]
esm_layer = original_esm_model.layers[i]
# self attention
self_attn: EsmSelfAttention = layer.attention.self
@ -183,7 +237,17 @@ def convert_esm_checkpoint_to_pytorch(
bert_output.dense.bias = esm_layer.fc2.bias
# end of layer
if classification_head:
if is_folding_model:
model.esm_s_combine.data = esm.esm_s_combine.data
transfer_and_check_weights(esm.embedding, model.embedding)
transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
transfer_and_check_weights(esm.trunk, model.trunk)
transfer_and_check_weights(esm.distogram_head, model.distogram_head)
transfer_and_check_weights(esm.ptm_head, model.ptm_head)
transfer_and_check_weights(esm.lm_head, model.lm_head)
transfer_and_check_weights(esm.lddt_head, model.lddt_head)
elif classification_head:
model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight
model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
@ -195,15 +259,19 @@ def convert_esm_checkpoint_to_pytorch(
model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
model.lm_head.decoder.weight = esm.lm_head.weight
model.lm_head.decoder.bias = esm.lm_head.bias
model.lm_head.bias = esm.lm_head.bias
# Let's check that we get the same results.
batch_converter = alphabet.get_batch_converter()
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
if is_folding_model:
# Folding models aren't trained on masked inputs and don't like mask tokens.
sample_data = SAMPLE_DATA[:2]
else:
sample_data = SAMPLE_DATA
batch_labels, batch_strs, batch_tokens = batch_converter(SAMPLE_DATA)
batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
# Prepare tokenizer and make sure it matches
with TemporaryDirectory() as tempdir:
vocab = "\n".join(alphabet.all_toks)
@ -211,32 +279,66 @@ def convert_esm_checkpoint_to_pytorch(
vocab_file.write_text(vocab)
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
hf_tokens = hf_tokenizer([row[1] for row in SAMPLE_DATA], return_tensors="pt", padding=True)
hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
success = torch.all(hf_tokens["input_ids"] == batch_tokens)
print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
if not success:
raise Exception("Tokenization does not match!")
with torch.no_grad():
our_output = model(**hf_tokens, output_hidden_states=True)
our_output = our_output["logits"]
if classification_head:
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
if is_folding_model:
# Let's test the model in parts
# ESMFold always converts the ESM stem to float16, which requires float16 ops
# that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
# ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
# original and the converted model on the GPU at the same time.
our_output = model.cuda()(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
else:
their_output = esm(batch_tokens, repr_layers=list(range(999)))
their_output = their_output["logits"]
print(our_output.shape, their_output.shape)
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
success = torch.allclose(our_output, their_output, atol=3e-4)
print("Do both models output the same tensors?", "🔥" if success else "💩")
our_output = model(**hf_tokens, output_hidden_states=True)
our_output = our_output["logits"]
if classification_head:
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
else:
their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
their_output = their_output["logits"]
if not success:
raise Exception("Something went wRoNg")
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
else:
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
success = torch.allclose(our_output, their_output, atol=1e-5)
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
print("Do both models output the same tensors?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
reloaded = model_class.from_pretrained(pytorch_dump_folder_path).cuda()
reloaded_output = reloaded(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - reloaded_output["positions"])).item()
success = torch.allclose(our_output["positions"], reloaded_output["positions"], atol=1e-6)
else:
max_absolute_diff = torch.max(torch.abs(our_output - reloaded_output["logits"])).item()
success = torch.allclose(our_output, reloaded_output["logits"], atol=1e-6)
print(f"max_absolute_diff = {max_absolute_diff}")
print("Does the model output the same tensors after reloading?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
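A hypothetical invocation of the conversion function above, assuming the script is importable from its usual location and that the original `esm` package (with the esmfold extras) is installed; per the comments in the script itself, ESMFold conversion is run on GPU.

```python
from transformers.models.esm.convert_esm import convert_esm_checkpoint_to_pytorch  # assumed module path

convert_esm_checkpoint_to_pytorch(
    model="esmfold_v1",  # key into MODEL_MAPPING above
    pytorch_dump_folder_path="./esmfold_v1_hf",
    classification_head=False,
    push_to_repo=None,  # assumed: leave unset to skip uploading
    auth_token=None,
)
```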


@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch ESM model."""
import math
from typing import List, Optional, Tuple, Union
import torch
@ -21,7 +22,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@ -30,12 +30,7 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_esm import EsmConfig
@ -66,6 +61,13 @@ def apply_rotary_pos_emb(x, cos, sin):
return (x * cos) + (rotate_half(x) * sin)
def gelu(x):
"""
This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class RotaryEmbedding(torch.nn.Module):
"""
Rotary position embeddings based on those in
@ -163,7 +165,9 @@ class EsmEmbeddings(nn.Module):
mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
src_lengths = attention_mask.sum(-1)
mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
embeddings.dtype
)
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
@ -172,7 +176,7 @@ class EsmEmbeddings(nn.Module):
if self.layer_norm is not None:
embeddings = self.layer_norm(embeddings)
if attention_mask is not None:
embeddings = embeddings * attention_mask.unsqueeze(-1)
embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
# Matt: I think this line was copied incorrectly from BERT, disabling it for now.
# embeddings = self.dropout(embeddings)
return embeddings
@ -398,19 +402,14 @@ class EsmAttention(nn.Module):
return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class EsmIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = gelu(hidden_states)
return hidden_states
@ -497,15 +496,13 @@ class EsmLayer(nn.Module):
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
layer_output = self.feed_forward_chunk(attention_output)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):

File diff suppressed because it is too large.


@ -0,0 +1,8 @@
# flake8: noqa
from .chunk_utils import chunk_layer
from .data_transforms import make_atom14_masks
from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
from .loss import compute_predicted_aligned_error, compute_tm
from .protein import Protein as OFProtein
from .protein import to_pdb
from .rigid_utils import Rigid, Rotation


@ -0,0 +1,398 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
import torch
from .tensor_utils import tensor_tree_map, tree_map
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
slices, and perhaps even the shortest possible (I'm pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if start_edges is None:
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if end_edges is None:
end_edges = [e == (d - 1) for e, d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if len(start) == 0:
return [tuple()]
elif len(start) == 1:
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s, e in zip(start, end):
if s == e:
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if divergence_idx == len(dims):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s
for s in _get_minimal_slice_set(
start[divergence_idx + 1 :],
[d - 1 for d in dims[divergence_idx + 1 :]],
dims[divergence_idx + 1 :],
start_edges=start_edges[divergence_idx + 1 :],
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s
for s in _get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1 :]],
end[divergence_idx + 1 :],
dims[divergence_idx + 1 :],
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
end_edges=end_edges[divergence_idx + 1 :],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if start_edges[divergence_idx] and end_edges[divergence_idx]:
slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif start_edges[divergence_idx]:
slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif end_edges[divergence_idx]:
slices.extend(upper())
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if middle_ground > 1:
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
slices.extend(lower())
return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only
reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk
size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
_out: Any = None,
_add_into_out: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples,
and dicts with torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch
dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined
as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product
of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly
slower than the default setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
if not low_mem:
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if _out is not None:
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
def _select_chunk(t):
return t[i : i + chunk_size] if t.shape[0] != 1 else t
i = 0
out = prepped_outputs
for _ in range(no_chunks):
# Chunk the input
if not low_mem:
select_chunk = _select_chunk
else:
select_chunk = partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims),
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
if _add_into_out:
v[i : i + chunk_size] += d2[k]
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
if _add_into_out:
x1[i : i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
if _add_into_out:
out[i : i + chunk_size] += output_chunk
else:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out)
return out
class ChunkSizeTuner:
def __init__(
self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
logging.info("Tuning chunk size...")
if min_chunk_size >= self.max_chunk_size:
return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4
def test_chunk_size(chunk_size):
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
return True
except RuntimeError:
return False
min_viable_chunk_size_index = 0
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
viable = test_chunk_size(candidates[i])
if not viable:
i = (min_viable_chunk_size_index + i) // 2
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
return candidates[min_viable_chunk_size_index]
def _compare_arg_caches(self, ac1, ac2):
consistent = True
for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2)
if type(ac1) is list or type(ac1) is tuple:
consistent &= self._compare_arg_caches(a1, a2)
elif type(ac1) is dict:
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
consistent &= self._compare_arg_caches(a1_items, a2_items)
else:
consistent &= a1 == a2
return consistent
def tune_chunk_size(
self,
representative_fn: Callable,
args: Tuple[Any],
min_chunk_size: int,
) -> int:
consistent = True
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object)
if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data)
consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
else:
# No cached data from a previous call, so we have to tune from scratch
consistent = False
if not consistent:
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
args,
min_chunk_size,
)
self.cached_arg_data = arg_data
return self.cached_chunk_size
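A small usage sketch for `chunk_layer` (the `openfold_utils` import path is an assumption based on the package `__init__` above):

```python
import torch
from transformers.models.esm.openfold_utils.chunk_utils import chunk_layer

def pairwise_sum(a, b):
    return a + b

a = torch.randn(4, 64, 8)
b = torch.randn(4, 64, 8)

# The two leading dimensions (4, 64) are treated as batch dimensions and flattened into
# 4 * 64 = 256 sub-batches; the layer is applied to 16 sub-batches at a time and the
# chunked outputs are reassembled into the original batch shape.
chunked = chunk_layer(pairwise_sum, {"a": a, "b": b}, chunk_size=16, no_batch_dims=2)

assert torch.allclose(chunked, a + b)
```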


@ -0,0 +1,92 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
)
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
dtype=torch.float32,
device=protein["aatype"].device,
)
protein_aatype = protein["aatype"].to(torch.long)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask
return protein
def make_atom14_masks_np(batch):
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
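A quick sketch of what `make_atom14_masks` adds to a feature dict (import path assumed; residue type 0 is alanine in the standard AlphaFold ordering):

```python
import torch
from transformers.models.esm.openfold_utils.data_transforms import make_atom14_masks

# A 10-residue chain, all residues set to type 0
protein = make_atom14_masks({"aatype": torch.zeros(10, dtype=torch.long)})

print(protein["atom14_atom_exists"].shape)  # torch.Size([10, 14])
print(protein["residx_atom14_to_atom37"].shape)  # torch.Size([10, 14])
print(protein["atom37_atom_exists"].shape)  # torch.Size([10, 37])
```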


@ -0,0 +1,234 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
all_atom_masks[..., cb_idx],
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
dim=-2,
no_batch_dims=len(atom14.shape[:-2]),
)
atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
return atom37_data
def build_template_angle_feat(template_feats):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
torsion_angles_mask,
],
dim=-1,
)
return template_angle_feat
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"],
rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1))
to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1))
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
rigids = Rigid.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = rigid_vec * inv_distance_scalar[..., None]
if not use_unit_vector:
unit_vector = unit_vector * 0.0
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
act = torch.cat(to_concat, dim=-1)
act = act * template_mask_2d[..., None]
return act
def build_extra_msa_feat(batch):
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
return torch.cat(msa_feat, dim=-1)
def torsion_angles_to_frames(
r: Rigid,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# [
# [1, 0 , 0 ],
# [0, a_2,-a_1],
# [0, a_1, a_2]
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]
chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
# [*, N, 14]
group_mask = group_idx[aatype, ...]
# [*, N, 14, 8]
group_mask = nn.functional.one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
# [*, N, 14, 8]
t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
# [*, N, 14, 1]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
return pred_positions
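Continuing the sketch above, `atom14_to_atom37` re-indexes per-residue atom14 coordinates into the fixed 37-atom layout using the index tensors produced by `make_atom14_masks` (import paths assumed):

```python
import torch
from transformers.models.esm.openfold_utils.data_transforms import make_atom14_masks
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

batch = make_atom14_masks({"aatype": torch.zeros(10, dtype=torch.long)})
atom14_positions = torch.randn(10, 14, 3)  # dummy coordinates in the compact atom14 layout

atom37_positions = atom14_to_atom37(atom14_positions, batch)
print(atom37_positions.shape)  # torch.Size([10, 37, 3]); absent atoms are zeroed by the atom37 mask
```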


@ -0,0 +1,105 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple
import torch
def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
return bin_centers
def _calculate_expected_aligned_error(
alignment_confidence_breaks: torch.Tensor,
aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1],
)
def compute_predicted_aligned_error(
logits: torch.Tensor,
max_bin: int = 31,
no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [*, num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
max_bin: Maximum bin value
no_bins: Number of bins
Returns:
aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: [*] the maximum predicted error possible.
"""
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
(predicted_aligned_error, max_predicted_aligned_error,) = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
"aligned_confidence_probs": aligned_confidence_probs,
"predicted_aligned_error": predicted_aligned_error,
"max_predicted_aligned_error": max_predicted_aligned_error,
}
def compute_tm(
logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
bin_centers = _calculate_bin_centers(boundaries)
torch.sum(residue_weights)
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
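A sketch of the confidence post-processing on dummy logits (import path assumed; real logits come from the model's aligned-error head):

```python
import torch
from transformers.models.esm.openfold_utils.loss import compute_predicted_aligned_error, compute_tm

# Fake pairwise aligned-error logits for a single 50-residue prediction, 64 error bins
logits = torch.randn(1, 50, 50, 64)

pae = compute_predicted_aligned_error(logits, max_bin=31, no_bins=64)
print(pae["predicted_aligned_error"].shape)  # torch.Size([1, 50, 50])
print(pae["max_predicted_aligned_error"])  # largest representable error (top bin centre)

ptm = compute_tm(logits, max_bin=31, no_bins=64)  # predicted TM-score summarising the whole prediction
```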


@ -0,0 +1,329 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import re
import string
from typing import Any, Mapping, Optional, Sequence
import numpy as np
from . import residue_constants
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
# Chain indices for multi-chain predictions
chain_index: Optional[np.ndarray] = None
# Optional remark about the protein. Included as a comment in output PDB
# files
remark: Optional[str] = None
# Templates used to generate this protein (prediction-only)
parents: Optional[Sequence[str]] = None
# Chain corresponding to each parent
parents_chain_index: Optional[Sequence[int]] = None
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r"(\[[A-Z]+\]\n)"
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
atoms = ["N", "CA", "C"]
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if "[PRIMARY]" == g[0]:
seq = list(g[1][0].strip())  # use a list so non-standard residues can be overwritten below
for i in range(len(seq)):
if seq[i] not in residue_constants.restypes:
seq[i] = "X"
aatype = np.array(
[residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
)
elif "[TERTIARY]" == g[0]:
tertiary = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3])
atom_positions *= PICO_TO_ANGSTROM
elif "[MASK]" == g[0]:
mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
atom_mask = np.zeros(
(
len(mask),
residue_constants.atom_type_num,
)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
aatype=aatype,
residue_index=np.arange(len(aatype)),
b_factors=None,
)
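# Illustrative usage sketch, not part of the original module: parse a minimal two-residue
# ProteinNet-style record. The sequence, coordinates (in picometers, N/CA/C per residue) and
# mask below are placeholder values.
_example_record = (
    "[PRIMARY]\n"
    "GA\n"
    "[TERTIARY]\n"
    "0. 100. 200. 300. 400. 500.\n"  # x coordinates for N, CA, C of residue 1, then residue 2
    "0. 0. 0. 0. 0. 0.\n"  # y coordinates
    "0. 0. 0. 0. 0. 0.\n"  # z coordinates
    "[MASK]\n"
    "++\n"
)
_example_protein = from_proteinnet_string(_example_record)
assert _example_protein.atom_positions.shape == (2, residue_constants.atom_type_num, 3)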
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
pdb_headers = []
remark = prot.remark
if remark is not None:
pdb_headers.append(f"REMARK {remark}")
parents = prot.parents
parents_chain_index = prot.parents_chain_index
if parents_chain_index is not None:
parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
if parents is None or len(parents) == 0:
parents = ["N/A"]
pdb_headers.append(f"PARENT {' '.join(parents)}")
return pdb_headers
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
"""Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines = []
lines = pdb_str.split("\n")
remark = prot.remark
if remark is not None:
out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain = None
if prot.parents is not None and len(prot.parents) > 0:
parents_per_chain = []
if prot.parents_chain_index is not None:
parent_dict = {}
for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p)
max_idx = max([int(chain_idx) for chain_idx in parent_dict])
for i in range(max_idx + 1):
chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents)
else:
parents_per_chain.append(prot.parents)
else:
parents_per_chain = [["N/A"]]
def make_parent_line(p):
return f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
chain_counter = 0
for i, l in enumerate(lines):
if "PARENT" not in l and "REMARK" not in l:
out_pdb_lines.append(l)
if "TER" in l and "END" not in lines[i + 1]:
chain_counter += 1
if chain_counter < len(parents_per_chain):
chain_parents = parents_per_chain[chain_counter]
else:
chain_parents = ["N/A"]
out_pdb_lines.append(make_parent_line(chain_parents))
return "\n".join(out_pdb_lines)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
def res_1to3(r):
return residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
headers = get_pdb_headers(prot)
if len(headers) > 0:
pdb_lines.extend(headers)
n = aatype.shape[0]
atom_index = 1
prev_chain_index = 0
chain_tags = string.ascii_uppercase
# Add all atom sites.
for i in range(n):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ""
chain_tag = "A"
if chain_index is not None:
chain_tag = chain_tags[chain_index[i]]
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_tag:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
should_terminate = i == n - 1
if chain_index is not None:
if i != n - 1 and chain_index[i + 1] != prev_chain_index:
should_terminate = True
prev_chain_index = chain_index[i + 1]
if should_terminate:
# Close the chain.
chain_end = "TER"
chain_termination_line = (
f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}"
)
pdb_lines.append(chain_termination_line)
atom_index += 1
if i != n - 1:
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function
computes a mask according to heavy atoms that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
chain_index: Optional[np.ndarray] = None,
remark: Optional[str] = None,
parents: Optional[Sequence[str]] = None,
parents_chain_index: Optional[Sequence[int]] = None,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
parents_chain_index: (Optional) Chain index corresponding to each of the templates in `parents`
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors,
chain_index=chain_index,
remark=remark,
parents=parents,
parents_chain_index=parents_chain_index,
)
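A minimal end-to-end sketch of how the pieces above fit together (assuming the module-level `from_prediction`, `to_pdb` and `residue_constants` names; the arrays are zero-filled placeholders and only their shapes matter):

import numpy as np

num_res = 4
features = {
    "aatype": np.zeros(num_res, dtype=np.int64),  # placeholder: every residue set to restype 0
    "residue_index": np.arange(num_res, dtype=np.int64),
}
result = {
    "final_atom_positions": np.zeros((num_res, residue_constants.atom_type_num, 3), dtype=np.float32),
    "final_atom_mask": np.ones((num_res, residue_constants.atom_type_num), dtype=np.float32),
}
prot = from_prediction(features=features, result=result, remark="illustrative example")
pdb_string = to_pdb(prot)  # REMARK/PARENT headers followed by ATOM records, TER and END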

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,116 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import List
import torch
import torch.nn as nn
def add(m1, m2, inplace):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if not inplace:
m1 = m1 + m2
else:
m1 += m2
return m1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int):
return t.reshape(t.shape[:-no_dims] + (-1,))
def masked_mean(mask, value, dim, eps=1e-4):
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
return torch.bucketize(dists, boundaries)
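# Illustrative sketch, not part of the original module: bin pairwise CA distances into a distogram.
_ca = torch.randn(16, 3) * 5.0  # placeholder CA coordinates, [num_res, 3]
_bins = pts_to_distogram(_ca)  # integer bin indices, shape [num_res, num_res], values in [0, no_bins - 1]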
def dict_multimap(fn, dicts):
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
return new_dict
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
ranges = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
# Matt note: Editing this to get around the behaviour of using a list as an array index changing
# in recent Numpy versions
return data[tuple(ranges)]
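# Illustrative shape check, not part of the original module: with two leading batch dimensions,
# indices along dim=-2 select among the atoms of each residue.
_data = torch.randn(2, 5, 37, 3)  # [batch, num_res, num_atoms, 3]
_inds = torch.zeros(2, 5, 4, dtype=torch.long)  # [batch, num_res, 4] atom indices
assert batched_gather(_data, _inds, dim=-2, no_batch_dims=2).shape == (2, 5, 4, 3)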
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
new_dict = {}
for k, v in dic.items():
if type(v) is dict:
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
return new_dict
def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif isinstance(tree, leaf_type):
return fn(tree)
else:
raise ValueError(f"Tree structures of type {type(tree)} are not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
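As a small illustration of `tensor_tree_map` (a sketch, not part of the diff; the output structure below is made up): it applies a function to every tensor leaf of an arbitrarily nested dict/list/tuple, for example to detach a structured model output and move it to CPU:

outputs = {
    "frames": torch.randn(8, 1, 7),
    "sidechains": [torch.randn(1, 10, 14, 3), torch.randn(1, 10, 7, 2)],
}
cpu_outputs = tensor_tree_map(lambda t: t.detach().to("cpu"), outputs)
assert cpu_outputs["sidechains"][0].device.type == "cpu"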

View File

@@ -1985,6 +1985,13 @@ class ErniePreTrainedModel(metaclass=DummyObject):
ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class EsmFoldPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EsmForMaskedLM(metaclass=DummyObject):
_backends = ["torch"]
@@ -1992,6 +1999,13 @@ class EsmForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class EsmForProteinFolding(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EsmForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
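These dummy classes exist so that importing the names from `transformers` still succeeds without PyTorch installed, and only fails at the point of use. A sketch of the behaviour (the exact error wording comes from `requires_backends` and is not reproduced here):

from transformers.utils.dummy_pt_objects import EsmForProteinFolding

try:
    EsmForProteinFolding()  # in an environment without PyTorch this triggers requires_backends(...)
except ImportError as err:
    print(err)  # message asks the user to install PyTorch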

View File

@@ -20,7 +20,6 @@ import unittest
from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -49,7 +48,7 @@ class EsmModelTester:
self.use_input_mask = True
self.use_token_type_ids = False
self.use_labels = True
self.vocab_size = 99
self.vocab_size = 33
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
@@ -145,7 +144,7 @@ class EsmModelTester:
@require_torch
class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class EsmModelTest(ModelTesterMixin, unittest.TestCase):
test_mismatched_shapes = False
@@ -253,28 +252,32 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_masked_lm(self):
with torch.no_grad():
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
vocab_size = 33
expected_shape = torch.Size((1, 6, vocab_size))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_no_head(self):
with torch.no_grad():
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

View File

@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ESM model. """
import unittest
from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
import torch
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
class EsmFoldModelTester:
def __init__(
self,
parent,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = False
self.use_input_mask = True
self.use_token_type_ids = False
self.use_labels = False
self.vocab_size = 19
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 512
self.type_vocab_size = 16
self.type_sequence_label_size = 2
self.initializer_range = 0.02
self.num_labels = 3
self.num_choices = 4
self.scope = None
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
config = EsmConfig(
vocab_size=33,
hidden_size=self.hidden_size,
pad_token_id=1,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
is_folding_model=True,
esmfold_config={"trunk": {"num_blocks": 2}, "fp16_esm": False},
)
return config
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = EsmForProteinFolding(config=config).float()
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
result = model(input_ids)
self.parent.assertEqual(result.positions.shape, (8, self.batch_size, self.seq_length, 14, 3))
self.parent.assertEqual(result.angles.shape, (8, self.batch_size, self.seq_length, 7, 2))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):
test_mismatched_shapes = False
all_model_classes = (EsmForProteinFolding,) if is_torch_available() else ()
all_generative_model_classes = ()
test_sequence_classification_problem_types = False
def setUp(self):
self.model_tester = EsmFoldModelTester(self)
self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("Does not support attention outputs")
def test_attention_outputs(self):
pass
@unittest.skip
def test_correct_missing_keys(self):
pass
@unittest.skip("Esm does not support embedding resizing")
def test_resize_embeddings_untied(self):
pass
@unittest.skip("Esm does not support embedding resizing")
def test_resize_tokens_embeddings(self):
pass
@unittest.skip("ESMFold does not support passing input embeds!")
def test_inputs_embeds(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning_integration(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning_save_load_from_config_init(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning_save_load_from_pretrained(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_headmasking(self):
pass
@unittest.skip("ESMFold does not output hidden states in the normal way.")
def test_hidden_states_output(self):
pass
@unittest.skip("ESMfold does not output hidden states in the normal way.")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip("ESMFold only has one output format.")
def test_model_outputs_equivalence(self):
pass
@unittest.skip("This test doesn't work for ESMFold and doesn't test core functionality")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("ESMFold does not support input chunking.")
def test_feed_forward_chunking(self):
pass
@unittest.skip("ESMFold doesn't respect you and it certainly doesn't respect your initialization arguments.")
def test_initialization(self):
pass
@unittest.skip("ESMFold doesn't support torchscript compilation.")
def test_torchscript_output_attentions(self):
pass
@unittest.skip("ESMFold doesn't support torchscript compilation.")
def test_torchscript_output_hidden_state(self):
pass
@unittest.skip("ESMFold doesn't support torchscript compilation.")
def test_torchscript_simple(self):
pass
@unittest.skip("ESMFold doesn't support data parallel.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_protein_folding(self):
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
position_outputs = model(input_ids)["positions"]
expected_slice = torch.tensor([2.5828, 0.7993, -10.9334], dtype=torch.float32)
self.assertTrue(torch.allclose(position_outputs[0, 0, 0, 0], expected_slice, atol=1e-4))
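Outside the test harness, the same slow-test setup reads as a short inference sketch (the checkpoint name and token IDs are taken from the test above; running in float32 on CPU is just one option):

import torch
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding

model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])  # same toy sequence as the test
with torch.no_grad():
    positions = model(input_ids)["positions"]  # shaped like (8, batch, seq_len, 14, 3) per the shape checks above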

View File

@@ -254,7 +254,7 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tf
class TFEsmModelIntegrationTest(unittest.TestCase):
@slow
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
def test_inference_masked_lm(self):
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
@@ -268,7 +268,7 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@slow
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
def test_inference_no_head(self):
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

View File

@@ -268,6 +268,7 @@ def get_transformers_submodules():
IGNORE_SUBMODULES = [
"convert_pytorch_checkpoint_to_tf2",
"modeling_flax_pytorch_utils",
"models.esm.openfold_utils",
]

View File

@@ -140,6 +140,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"EsmForProteinFolding",
"TimeSeriesTransformerForPrediction",
"PegasusXEncoder",
"PegasusXDecoder",