Add ESMFold (#19977)
* initial commit
* First draft that gets outputs without crashing!
* Add all the ported openfold dependencies
* testing
* Restructure config files for ESMFold
* Debugging to find output discrepancies
* Mainly style
* Make model runnable without extra deps
* Remove utils and merge them to the modeling file
* Use correct gelu and remove some debug prints
* More cleanup
* Update esm docs
* Update conversion script to support ESMFold properly
* Port some top-level changes from ESMFold repo
* Expand EsmFold docstrings
* Make attention_mask optional (default to all 1s)
* Add inference test for ESMFold
* Use config and not n kwargs
* Add modeling output class
* Remove einops
* Remove chunking in ESM FFN
* Update tests for ESMFold
* Quality
* REpo consistency
* Remove tree dependency from ESMFold
* make fixup
* Add an error in case my structure map function breaks later
* Remove needless code
* Stop auto-casting the LM to float16 so CPU tests pass
* Stop auto-casting the LM to float16 so CPU tests pass
* Final test updates
* Split test file
* Copyright and quality
* Unpin PyTorch to see built doc
* Fix config file to_dict() method
* Add some docstrings to the output
* Skip TF checkpoint tests for ESM until we reupload those
* make fixup
* More docstrings
* Unpin to get even with main
* Flag example to write

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
parent 4c9e0f029e
commit 7f9b7b3f0e
@@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License.

## Overview

This page provides code and pre-trained weights for Transformer protein language models from Meta AI's Fundamental
-AI Research Team, providing the state-of-the-art ESM-2, and the previously released ESM-1b and ESM-1v. Transformer
-protein language models were introduced in the paper [Biological structure and function emerge from scaling
+AI Research Team, providing the state-of-the-art ESMFold and ESM-2, and the previously released ESM-1b and ESM-1v.
+Transformer protein language models were introduced in the paper [Biological structure and function emerge from scaling
unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by
Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott,
C. Lawrence Zitnick, Jerry Ma, and Rob Fergus.

@@ -27,6 +27,13 @@ It was released with the paper [Language models of protein sequences at the scal
structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie,
Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido and Alexander Rives.

+Also introduced in this paper was ESMFold. It uses an ESM-2 stem with a head that can predict folded protein
+structures with state-of-the-art accuracy. Unlike [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2),
+it relies on the token embeddings from the large pre-trained protein language model stem and does not perform a multiple
+sequence alignment (MSA) step at inference time, which means that ESMFold checkpoints are fully "standalone" -
+they do not require a database of known protein sequences and structures with associated external query tools
+to make predictions, and are much faster as a result.
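Because the checkpoints are standalone, inference needs nothing beyond the model and the tokenizer. A minimal sketch, assuming the `facebook/esmfold_v1` checkpoint and the `EsmForProteinFolding` class introduced in this PR (the example sequence is arbitrary):

```python
# Minimal ESMFold inference sketch. Assumes the "facebook/esmfold_v1" checkpoint;
# the output field name ("positions") follows the modeling/conversion code in this PR.
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
model.eval()

# A single protein sequence: no MSA, template database or external query tools are needed.
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKR"
inputs = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)

with torch.no_grad():
    outputs = model(**inputs)

# Predicted atom positions for the folded structure.
print(outputs["positions"].shape)
```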

The abstract from
"Biological structure and function emerge from scaling unsupervised learning to 250

@@ -63,17 +70,22 @@ order of magnitude faster than AlphaFold2, enabling exploration of the structura
proteins in practical timescales.*

Tips:

- ESM models are trained with a masked language modeling (MLM) objective (a short MLM sketch follows at the end of this overview).

The original code can be found [here](https://github.com/facebookresearch/esm) and
was developed by the Fundamental AI Research team at Meta AI.
-This model was contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
+ESM-1b, ESM-1v and ESM-2 were contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
and [Matt](https://huggingface.co/Rocketknight1).

+ESMFold was contributed to huggingface by [Matt](https://huggingface.co/Rocketknight1) and
+[Sylvain](https://huggingface.co/sgugger), with a big thank you to Nikita Smetanin, Roshan Rao and Tom Sercu for their
+help throughout the process!
+
+The HuggingFace port of ESMFold uses portions of the [openfold](https://github.com/aqlaboratory/openfold) library.
+The `openfold` library is licensed under the Apache License 2.0.
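Following up on the masked language modeling tip above, here is a hedged sketch of MLM inference with an ESM-2 checkpoint; the checkpoint name and the masked position are illustrative assumptions:

```python
# Masked language modeling sketch for ESM-2 (checkpoint name is an assumption).
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK"
inputs = tokenizer(sequence, return_tensors="pt")

# Mask one residue (position 5 here is arbitrary) and reconstruct it from context.
input_ids = inputs["input_ids"].clone()
input_ids[0, 5] = tokenizer.mask_token_id
with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=inputs["attention_mask"]).logits

predicted = tokenizer.convert_ids_to_tokens(logits[0, 5].argmax().item())
print(predicted)
```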
## EsmConfig

[[autodoc]] EsmConfig

@@ -108,6 +120,11 @@ and [Matt](https://huggingface.co/Rocketknight1).
[[autodoc]] EsmForTokenClassification
- forward

+## EsmForProteinFolding
+
+[[autodoc]] EsmForProteinFolding
+- forward
+
## TFEsmModel

[[autodoc]] TFEsmModel

@ -1265,7 +1265,9 @@ else:
|
||||
_import_structure["models.esm"].extend(
|
||||
[
|
||||
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"EsmFoldPreTrainedModel",
|
||||
"EsmForMaskedLM",
|
||||
"EsmForProteinFolding",
|
||||
"EsmForSequenceClassification",
|
||||
"EsmForTokenClassification",
|
||||
"EsmModel",
|
||||
@ -4144,7 +4146,9 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.esm import (
|
||||
ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
EsmFoldPreTrainedModel,
|
||||
EsmForMaskedLM,
|
||||
EsmForProteinFolding,
|
||||
EsmForSequenceClassification,
|
||||
EsmForTokenClassification,
|
||||
EsmModel,
|
||||
|
@ -39,6 +39,7 @@ else:
|
||||
"EsmModel",
|
||||
"EsmPreTrainedModel",
|
||||
]
|
||||
_import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
@ -55,7 +56,6 @@ else:
|
||||
"TFEsmPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
|
||||
from .tokenization_esm import EsmTokenizer
|
||||
@ -74,6 +74,7 @@ if TYPE_CHECKING:
|
||||
EsmModel,
|
||||
EsmPreTrainedModel,
|
||||
)
|
||||
from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 Facebook and The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -14,12 +14,16 @@
|
||||
# limitations under the License.
|
||||
""" ESM model configuration"""
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# TODO Update this
|
||||
ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json",
|
||||
# See all ESM models at https://huggingface.co/models?filter=esm
|
||||
@ -118,9 +122,12 @@ class EsmConfig(PretrainedConfig):
|
||||
classifier_dropout=None,
|
||||
emb_layer_norm_before=None,
|
||||
token_dropout=False,
|
||||
is_folding_model=False,
|
||||
esmfold_config=None,
|
||||
vocab_list=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
@ -138,5 +145,225 @@ class EsmConfig(PretrainedConfig):
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.emb_layer_norm_before = emb_layer_norm_before
|
||||
self.token_dropout = token_dropout
|
||||
self.mask_token_id = mask_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.is_folding_model = is_folding_model
|
||||
if is_folding_model:
|
||||
if esmfold_config is None:
|
||||
logger.info("No esmfold_config supplied for folding model, using default values.")
|
||||
esmfold_config = EsmFoldConfig()
|
||||
elif isinstance(esmfold_config, dict):
|
||||
esmfold_config = EsmFoldConfig(**esmfold_config)
|
||||
self.esmfold_config = esmfold_config
|
||||
if vocab_list is None:
|
||||
logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
|
||||
self.vocab_list = get_default_vocab_list()
|
||||
else:
|
||||
self.vocab_list = vocab_list
|
||||
else:
|
||||
self.esmfold_config = None
|
||||
self.vocab_list = None
|
||||
if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
|
||||
raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = super().to_dict()
|
||||
if isinstance(self.esmfold_config, EsmFoldConfig):
|
||||
output["esmfold_config"] = self.esmfold_config.to_dict()
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class EsmFoldConfig:
|
||||
esm_type: str = None
|
||||
fp16_esm: bool = True
|
||||
use_esm_attn_map: bool = False
|
||||
esm_ablate_pairwise: bool = False
|
||||
esm_ablate_sequence: bool = False
|
||||
esm_input_dropout: float = 0
|
||||
|
||||
embed_aa: bool = True
|
||||
bypass_lm: bool = False
|
||||
|
||||
lddt_head_hid_dim: int = 128
|
||||
trunk: "TrunkConfig" = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.trunk is None:
|
||||
self.trunk = TrunkConfig()
|
||||
elif isinstance(self.trunk, dict):
|
||||
self.trunk = TrunkConfig(**self.trunk)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = asdict(self)
|
||||
output["trunk"] = self.trunk.to_dict()
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrunkConfig:
|
||||
num_blocks: int = 48
|
||||
sequence_state_dim: int = 1024
|
||||
pairwise_state_dim: int = 128
|
||||
sequence_head_width: int = 32
|
||||
pairwise_head_width: int = 32
|
||||
position_bins: int = 32
|
||||
dropout: float = 0
|
||||
layer_drop: float = 0
|
||||
cpu_grad_checkpoint: bool = False
|
||||
max_recycles: int = 4
|
||||
chunk_size: Optional[int] = 128
|
||||
structure_module: "StructureModuleConfig" = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.structure_module is None:
|
||||
self.structure_module = StructureModuleConfig()
|
||||
elif isinstance(self.structure_module, dict):
|
||||
self.structure_module = StructureModuleConfig(**self.structure_module)
|
||||
|
||||
if self.max_recycles <= 0:
|
||||
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
|
||||
if self.sequence_state_dim % self.sequence_head_width != 0:
raise ValueError(
"`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
f" {self.sequence_state_dim} and {self.sequence_head_width}."
)
if self.pairwise_state_dim % self.pairwise_head_width != 0:
raise ValueError(
"`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
)
|
||||
|
||||
sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
|
||||
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
|
||||
|
||||
if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
|
||||
raise ValueError(
|
||||
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
|
||||
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
|
||||
)
|
||||
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
|
||||
raise ValueError(
|
||||
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
|
||||
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
|
||||
)
|
||||
if self.pairwise_state_dim % 2 != 0:
|
||||
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
|
||||
|
||||
if self.dropout >= 0.4:
|
||||
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = asdict(self)
|
||||
output["structure_module"] = self.structure_module.to_dict()
|
||||
return output
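As a worked check of the invariants enforced in `TrunkConfig.__post_init__` above, using only the defaults defined in this file:

```python
# Worked example of the trunk constraints with the default dimensions above.
sequence_state_dim, sequence_head_width = 1024, 32
pairwise_state_dim, pairwise_head_width = 128, 32

sequence_num_heads = sequence_state_dim // sequence_head_width  # 32 attention heads
pairwise_num_heads = pairwise_state_dim // pairwise_head_width  # 4 attention heads

assert sequence_state_dim == sequence_num_heads * sequence_head_width
assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
assert pairwise_state_dim % 2 == 0  # pairwise_state_dim must also be even
```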
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructureModuleConfig:
|
||||
"""
|
||||
Args:
|
||||
sequence_dim:
|
||||
Single representation channel dimension
|
||||
pairwise_dim:
|
||||
Pair representation channel dimension
|
||||
ipa_dim:
|
||||
IPA hidden channel dimension
|
||||
resnet_dim:
|
||||
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
|
||||
num_heads_ipa:
|
||||
Number of IPA heads
|
||||
num_qk_points:
|
||||
Number of query/key points to generate during IPA
|
||||
num_v_points:
|
||||
Number of value points to generate during IPA
|
||||
dropout_rate:
|
||||
Dropout rate used throughout the layer
|
||||
num_blocks:
|
||||
Number of structure module blocks
|
||||
num_transition_layers:
|
||||
Number of layers in the single representation transition (Alg. 23 lines 8-9)
|
||||
num_resnet_blocks:
|
||||
Number of blocks in the angle resnet
|
||||
num_angles:
|
||||
Number of angles to generate in the angle resnet
|
||||
trans_scale_factor:
|
||||
Scale of single representation transition hidden dimension
|
||||
epsilon:
|
||||
Small number used in angle resnet normalization
|
||||
inf:
|
||||
Large number used for attention masking
|
||||
"""
|
||||
|
||||
sequence_dim: int = 384
|
||||
pairwise_dim: int = 128
|
||||
ipa_dim: int = 16
|
||||
resnet_dim: int = 128
|
||||
num_heads_ipa: int = 12
|
||||
num_qk_points: int = 4
|
||||
num_v_points: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
num_blocks: int = 8
|
||||
num_transition_layers: int = 1
|
||||
num_resnet_blocks: int = 2
|
||||
num_angles: int = 7
|
||||
trans_scale_factor: int = 10
|
||||
epsilon: float = 1e-8
|
||||
inf: float = 1e5
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def get_default_vocab_list():
|
||||
return (
|
||||
"<cls>",
|
||||
"<pad>",
|
||||
"<eos>",
|
||||
"<unk>",
|
||||
"L",
|
||||
"A",
|
||||
"G",
|
||||
"V",
|
||||
"S",
|
||||
"E",
|
||||
"R",
|
||||
"T",
|
||||
"I",
|
||||
"D",
|
||||
"P",
|
||||
"K",
|
||||
"Q",
|
||||
"N",
|
||||
"F",
|
||||
"Y",
|
||||
"M",
|
||||
"H",
|
||||
"W",
|
||||
"C",
|
||||
"X",
|
||||
"B",
|
||||
"U",
|
||||
"Z",
|
||||
"O",
|
||||
".",
|
||||
"-",
|
||||
"<null_1>",
|
||||
"<mask>",
|
||||
)
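The default vocabulary above is what the conversion script later in this diff writes out when it builds a tokenizer; a hedged sketch of that round trip (the temporary vocab file name is arbitrary):

```python
# Hedged sketch: turning the default ESM-2 vocabulary into a vocab file for EsmTokenizer.
from pathlib import Path
from tempfile import TemporaryDirectory

from transformers.models.esm.configuration_esm import get_default_vocab_list
from transformers.models.esm.tokenization_esm import EsmTokenizer

with TemporaryDirectory() as tempdir:
    vocab_file = Path(tempdir) / "vocab.txt"
    vocab_file.write_text("\n".join(get_default_vocab_list()))
    tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
    print(tokenizer.convert_tokens_to_ids(["<cls>", "M", "K", "<eos>"]))
```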
|
||||
|
@ -23,7 +23,8 @@ from tempfile import TemporaryDirectory
|
||||
import torch
|
||||
|
||||
import esm as esm_module
|
||||
from transformers.models.esm.configuration_esm import EsmConfig
|
||||
from esm.esmfold.v1.pretrained import esmfold_v1
|
||||
from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
|
||||
from transformers.models.esm.modeling_esm import (
|
||||
EsmForMaskedLM,
|
||||
EsmForSequenceClassification,
|
||||
@ -33,6 +34,7 @@ from transformers.models.esm.modeling_esm import (
|
||||
EsmSelfAttention,
|
||||
EsmSelfOutput,
|
||||
)
|
||||
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
|
||||
from transformers.models.esm.tokenization_esm import EsmTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
@ -60,19 +62,51 @@ MODEL_MAPPING = {
|
||||
"esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
|
||||
"esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
|
||||
"esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
|
||||
"esmfold_v1": esmfold_v1,
|
||||
}
|
||||
|
||||
|
||||
def transfer_and_check_weights(original_module, our_module):
|
||||
status = our_module.load_state_dict(original_module.state_dict())
|
||||
if status.missing_keys:
|
||||
raise ValueError(f"Missing keys: {status.missing_keys}")
|
||||
if status.unexpected_keys:
|
||||
raise ValueError(f"Unexpected keys: {status.unexpected_keys}")
|
||||
|
||||
|
||||
def convert_esm_checkpoint_to_pytorch(
|
||||
model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak esm's weights to our BERT structure.
|
||||
"""
|
||||
esm, alphabet = MODEL_MAPPING[model]()
|
||||
if model.startswith("esmfold"):
|
||||
esm = MODEL_MAPPING[model]()
|
||||
alphabet = esm.esm.alphabet
|
||||
else:
|
||||
esm, alphabet = MODEL_MAPPING[model]()
|
||||
esm.eval() # disable dropout
|
||||
esm_sent_encoder = esm
|
||||
if hasattr(esm, "args"):
|
||||
|
||||
if model.startswith("esmfold"):
|
||||
embed_dim = esm.esm.embed_dim
|
||||
num_layers = esm.esm.num_layers
|
||||
num_attention_heads = esm.esm.attention_heads
|
||||
intermediate_size = 4 * embed_dim
|
||||
token_dropout = esm.esm.token_dropout
|
||||
emb_layer_norm_before = False # This code path does not exist in ESM-2
|
||||
position_embedding_type = "rotary"
|
||||
is_folding_model = True
|
||||
esmfold_config = EsmFoldConfig()
|
||||
for key, val in esm.cfg.items():
|
||||
if hasattr(esmfold_config, key) and key != "trunk":
|
||||
setattr(esmfold_config, key, val)
|
||||
for key, val in esm.cfg.trunk.items():
|
||||
if hasattr(esmfold_config.trunk, key) and key != "structure_module":
|
||||
setattr(esmfold_config.trunk, key, val)
|
||||
for key, val in esm.cfg.trunk.structure_module.items():
|
||||
if hasattr(esmfold_config.trunk.structure_module, key):
|
||||
setattr(esmfold_config.trunk.structure_module, key, val)
|
||||
elif hasattr(esm, "args"):
|
||||
# Indicates an ESM-1b or ESM-1v model
|
||||
embed_dim = esm.args.embed_dim
|
||||
num_layers = esm.args.layers
|
||||
@ -81,6 +115,8 @@ def convert_esm_checkpoint_to_pytorch(
|
||||
token_dropout = esm.args.token_dropout
|
||||
emb_layer_norm_before = True if esm.emb_layer_norm_before else False
|
||||
position_embedding_type = "absolute"
|
||||
is_folding_model = False
|
||||
esmfold_config = None
|
||||
else:
|
||||
# Indicates an ESM-2 model
|
||||
embed_dim = esm.embed_dim
|
||||
@ -90,9 +126,18 @@ def convert_esm_checkpoint_to_pytorch(
|
||||
token_dropout = esm.token_dropout
|
||||
emb_layer_norm_before = False # This code path does not exist in ESM-2
|
||||
position_embedding_type = "rotary"
|
||||
is_folding_model = False
|
||||
esmfold_config = None
|
||||
|
||||
vocab_list = tuple(alphabet.all_toks)
|
||||
|
||||
if is_folding_model:
|
||||
original_esm_model = esm.esm
|
||||
else:
|
||||
original_esm_model = esm
|
||||
|
||||
config = EsmConfig(
|
||||
vocab_size=esm_sent_encoder.embed_tokens.num_embeddings,
|
||||
vocab_size=original_esm_model.embed_tokens.num_embeddings,
|
||||
mask_token_id=alphabet.mask_idx,
|
||||
hidden_size=embed_dim,
|
||||
num_hidden_layers=num_layers,
|
||||
@ -102,36 +147,45 @@ def convert_esm_checkpoint_to_pytorch(
|
||||
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
||||
attention_probs_dropout_prob=0.0,
|
||||
hidden_dropout_prob=0.0,
|
||||
pad_token_id=esm.padding_idx,
|
||||
pad_token_id=alphabet.padding_idx,
|
||||
emb_layer_norm_before=emb_layer_norm_before,
|
||||
token_dropout=token_dropout,
|
||||
position_embedding_type=position_embedding_type,
|
||||
is_folding_model=is_folding_model,
|
||||
esmfold_config=esmfold_config,
|
||||
vocab_list=vocab_list,
|
||||
)
|
||||
if classification_head:
|
||||
config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
|
||||
print("Our BERT config:", config)
|
||||
print("Our ESM config:", config)
|
||||
|
||||
model = EsmForSequenceClassification(config) if classification_head else EsmForMaskedLM(config)
|
||||
if model.startswith("esmfold"):
|
||||
model_class = EsmForProteinFolding
|
||||
elif classification_head:
|
||||
model_class = EsmForSequenceClassification
|
||||
else:
|
||||
model_class = EsmForMaskedLM
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
|
||||
# Now let's copy all the weights.
|
||||
# Embeddings
|
||||
model.esm.embeddings.word_embeddings.weight = esm_sent_encoder.embed_tokens.weight
|
||||
model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
|
||||
if position_embedding_type == "absolute":
|
||||
model.esm.embeddings.position_embeddings.weight = esm_sent_encoder.embed_positions.weight
|
||||
model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight
|
||||
|
||||
if config.emb_layer_norm_before:
|
||||
model.esm.embeddings.layer_norm.weight = esm_sent_encoder.emb_layer_norm_before.weight
|
||||
model.esm.embeddings.layer_norm.bias = esm_sent_encoder.emb_layer_norm_before.bias
|
||||
model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
|
||||
model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias
|
||||
|
||||
model.esm.encoder.emb_layer_norm_after.weight = esm_sent_encoder.emb_layer_norm_after.weight
|
||||
model.esm.encoder.emb_layer_norm_after.bias = esm_sent_encoder.emb_layer_norm_after.bias
|
||||
model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
|
||||
model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
# Encoder: start of layer
|
||||
layer: EsmLayer = model.esm.encoder.layer[i]
|
||||
# esm_layer: TransformerSentenceEncoderLayer = esm_sent_encoder.layers[i]
|
||||
esm_layer = esm_sent_encoder.layers[i]
|
||||
# esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i]
|
||||
esm_layer = original_esm_model.layers[i]
|
||||
|
||||
# self attention
|
||||
self_attn: EsmSelfAttention = layer.attention.self
|
||||
@ -183,7 +237,17 @@ def convert_esm_checkpoint_to_pytorch(
|
||||
bert_output.dense.bias = esm_layer.fc2.bias
|
||||
# end of layer
|
||||
|
||||
if classification_head:
|
||||
if is_folding_model:
|
||||
model.esm_s_combine.data = esm.esm_s_combine.data
|
||||
transfer_and_check_weights(esm.embedding, model.embedding)
|
||||
transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
|
||||
transfer_and_check_weights(esm.trunk, model.trunk)
|
||||
transfer_and_check_weights(esm.distogram_head, model.distogram_head)
|
||||
transfer_and_check_weights(esm.ptm_head, model.ptm_head)
|
||||
transfer_and_check_weights(esm.lm_head, model.lm_head)
|
||||
transfer_and_check_weights(esm.lddt_head, model.lddt_head)
|
||||
|
||||
elif classification_head:
|
||||
model.classifier.dense.weight = esm.classification_heads["mnli"].dense.weight
|
||||
model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
|
||||
model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
|
||||
@ -195,15 +259,19 @@ def convert_esm_checkpoint_to_pytorch(
|
||||
model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
|
||||
model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
|
||||
model.lm_head.decoder.weight = esm.lm_head.weight
|
||||
model.lm_head.decoder.bias = esm.lm_head.bias
|
||||
model.lm_head.bias = esm.lm_head.bias
|
||||
|
||||
# Let's check that we get the same results.
|
||||
batch_converter = alphabet.get_batch_converter()
|
||||
|
||||
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
|
||||
if is_folding_model:
|
||||
# Folding models aren't trained on masked inputs and don't like mask tokens.
|
||||
sample_data = SAMPLE_DATA[:2]
|
||||
else:
|
||||
sample_data = SAMPLE_DATA
|
||||
|
||||
batch_labels, batch_strs, batch_tokens = batch_converter(SAMPLE_DATA)
|
||||
|
||||
batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
|
||||
# Prepare tokenizer and make sure it matches
|
||||
with TemporaryDirectory() as tempdir:
|
||||
vocab = "\n".join(alphabet.all_toks)
|
||||
@ -211,32 +279,66 @@ def convert_esm_checkpoint_to_pytorch(
|
||||
vocab_file.write_text(vocab)
|
||||
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
|
||||
|
||||
hf_tokens = hf_tokenizer([row[1] for row in SAMPLE_DATA], return_tensors="pt", padding=True)
|
||||
hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
|
||||
success = torch.all(hf_tokens["input_ids"] == batch_tokens)
|
||||
print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
|
||||
if not success:
|
||||
raise Exception("Tokenization does not match!")
|
||||
|
||||
with torch.no_grad():
|
||||
our_output = model(**hf_tokens, output_hidden_states=True)
|
||||
our_output = our_output["logits"]
|
||||
if classification_head:
|
||||
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
|
||||
if is_folding_model:
|
||||
# Let's test the model in parts
|
||||
# ESMFold always converts the ESM stem to float16, which requires float16 ops
|
||||
# that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
|
||||
# ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
|
||||
# original and the converted model on the GPU at the same time.
|
||||
our_output = model.cuda()(
|
||||
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
|
||||
)
|
||||
their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
|
||||
else:
|
||||
their_output = esm(batch_tokens, repr_layers=list(range(999)))
|
||||
their_output = their_output["logits"]
|
||||
print(our_output.shape, their_output.shape)
|
||||
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
|
||||
success = torch.allclose(our_output, their_output, atol=3e-4)
|
||||
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
||||
our_output = model(**hf_tokens, output_hidden_states=True)
|
||||
our_output = our_output["logits"]
|
||||
if classification_head:
|
||||
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
|
||||
else:
|
||||
their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
|
||||
their_output = their_output["logits"]
|
||||
|
||||
if not success:
|
||||
raise Exception("Something went wRoNg")
|
||||
if is_folding_model:
|
||||
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
|
||||
success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
|
||||
else:
|
||||
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
||||
success = torch.allclose(our_output, their_output, atol=1e-5)
|
||||
|
||||
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
|
||||
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
||||
|
||||
if not success:
|
||||
raise Exception("Something went wRoNg")
|
||||
|
||||
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
reloaded = model_class.from_pretrained(pytorch_dump_folder_path).cuda()
|
||||
reloaded_output = reloaded(
|
||||
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
|
||||
)
|
||||
|
||||
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - reloaded_output["positions"])).item()
success = torch.allclose(our_output["positions"], reloaded_output["positions"], atol=1e-6)
else:
max_absolute_diff = torch.max(torch.abs(our_output - reloaded_output["logits"])).item()
success = torch.allclose(our_output, reloaded_output["logits"], atol=1e-6)
|
||||
|
||||
print(f"max_absolute_diff = {max_absolute_diff}")
|
||||
print("Does the model output the same tensors after reloading?", "🔥" if success else "💩")
|
||||
|
||||
if not success:
|
||||
raise Exception("Something went wRoNg")
|
||||
|
||||
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
|
||||
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch ESM model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -21,7 +22,6 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -30,12 +30,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import (
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import logging
|
||||
from .configuration_esm import EsmConfig
|
||||
|
||||
@ -66,6 +61,13 @@ def apply_rotary_pos_emb(x, cos, sin):
|
||||
return (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""
|
||||
This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
"""
|
||||
Rotary position embeddings based on those in
|
||||
@ -163,7 +165,9 @@ class EsmEmbeddings(nn.Module):
|
||||
mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
|
||||
src_lengths = attention_mask.sum(-1)
|
||||
mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
|
||||
embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
||||
embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
|
||||
embeddings.dtype
|
||||
)
|
||||
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
@ -172,7 +176,7 @@ class EsmEmbeddings(nn.Module):
|
||||
if self.layer_norm is not None:
|
||||
embeddings = self.layer_norm(embeddings)
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1)
|
||||
embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
|
||||
# Matt: I think this line was copied incorrectly from BERT, disabling it for now.
|
||||
# embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
@ -398,19 +402,14 @@ class EsmAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
|
||||
class EsmIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
hidden_states = gelu(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -497,15 +496,13 @@ class EsmLayer(nn.Module):
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
layer_output = self.feed_forward_chunk(attention_output)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
|
src/transformers/models/esm/modeling_esmfold.py (new file, 2307 lines) - diff suppressed because it is too large

src/transformers/models/esm/openfold_utils/__init__.py (new file, 8 lines)
@ -0,0 +1,8 @@
|
||||
# flake8: noqa
|
||||
from .chunk_utils import chunk_layer
|
||||
from .data_transforms import make_atom14_masks
|
||||
from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
|
||||
from .loss import compute_predicted_aligned_error, compute_tm
|
||||
from .protein import Protein as OFProtein
|
||||
from .protein import to_pdb
|
||||
from .rigid_utils import Rigid, Rotation
|
src/transformers/models/esm/openfold_utils/chunk_utils.py (new file, 398 lines)
@ -0,0 +1,398 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .tensor_utils import tensor_tree_map, tree_map
|
||||
|
||||
|
||||
def _fetch_dims(tree):
|
||||
shapes = []
|
||||
tree_type = type(tree)
|
||||
if tree_type is dict:
|
||||
for v in tree.values():
|
||||
shapes.extend(_fetch_dims(v))
|
||||
elif tree_type is list or tree_type is tuple:
|
||||
for t in tree:
|
||||
shapes.extend(_fetch_dims(t))
|
||||
elif tree_type is torch.Tensor:
|
||||
shapes.append(tree.shape)
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
return shapes
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _flat_idx_to_idx(
|
||||
flat_idx: int,
|
||||
dims: Tuple[int],
|
||||
) -> Tuple[int]:
|
||||
idx = []
|
||||
for d in reversed(dims):
|
||||
idx.append(flat_idx % d)
|
||||
flat_idx = flat_idx // d
|
||||
|
||||
return tuple(reversed(idx))
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _get_minimal_slice_set(
|
||||
start: Sequence[int],
|
||||
end: Sequence[int],
|
||||
dims: int,
|
||||
start_edges: Optional[Sequence[bool]] = None,
|
||||
end_edges: Optional[Sequence[bool]] = None,
|
||||
) -> Sequence[Tuple[int]]:
|
||||
"""
|
||||
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
|
||||
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
|
||||
slices, and perhaps even the shortest possible (I'm pretty sure it's the latter).
|
||||
|
||||
end is INCLUSIVE.
|
||||
"""
|
||||
# start_edges and end_edges both indicate whether, starting from any given
|
||||
# dimension, the start/end index is at the top/bottom edge of the
|
||||
# corresponding tensor, modeled as a tree
|
||||
def reduce_edge_list(l):
|
||||
tally = 1
|
||||
for i in range(len(l)):
|
||||
reversed_idx = -1 * (i + 1)
|
||||
l[reversed_idx] *= tally
|
||||
tally = l[reversed_idx]
|
||||
|
||||
if start_edges is None:
|
||||
start_edges = [s == 0 for s in start]
|
||||
reduce_edge_list(start_edges)
|
||||
if end_edges is None:
|
||||
end_edges = [e == (d - 1) for e, d in zip(end, dims)]
|
||||
reduce_edge_list(end_edges)
|
||||
|
||||
# Base cases. Either start/end are empty and we're done, or the final,
|
||||
# one-dimensional tensor can be simply sliced
|
||||
if len(start) == 0:
|
||||
return [tuple()]
|
||||
elif len(start) == 1:
|
||||
return [(slice(start[0], end[0] + 1),)]
|
||||
|
||||
slices = []
|
||||
path = []
|
||||
|
||||
# Dimensions common to start and end can be selected directly
|
||||
for s, e in zip(start, end):
|
||||
if s == e:
|
||||
path.append(slice(s, s + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
path = tuple(path)
|
||||
divergence_idx = len(path)
|
||||
|
||||
# start == end, and we're done
|
||||
if divergence_idx == len(dims):
|
||||
return [tuple(path)]
|
||||
|
||||
def upper():
|
||||
sdi = start[divergence_idx]
|
||||
return [
|
||||
path + (slice(sdi, sdi + 1),) + s
|
||||
for s in _get_minimal_slice_set(
|
||||
start[divergence_idx + 1 :],
|
||||
[d - 1 for d in dims[divergence_idx + 1 :]],
|
||||
dims[divergence_idx + 1 :],
|
||||
start_edges=start_edges[divergence_idx + 1 :],
|
||||
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
|
||||
)
|
||||
]
|
||||
|
||||
def lower():
|
||||
edi = end[divergence_idx]
|
||||
return [
|
||||
path + (slice(edi, edi + 1),) + s
|
||||
for s in _get_minimal_slice_set(
|
||||
[0 for _ in start[divergence_idx + 1 :]],
|
||||
end[divergence_idx + 1 :],
|
||||
dims[divergence_idx + 1 :],
|
||||
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
|
||||
end_edges=end_edges[divergence_idx + 1 :],
|
||||
)
|
||||
]
|
||||
|
||||
# If both start and end are at the edges of the subtree rooted at
|
||||
# divergence_idx, we can just select the whole subtree at once
|
||||
if start_edges[divergence_idx] and end_edges[divergence_idx]:
|
||||
slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
|
||||
# If just start is at the edge, we can grab almost all of the subtree,
|
||||
# treating only the ragged bottom edge as an edge case
|
||||
elif start_edges[divergence_idx]:
|
||||
slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
|
||||
slices.extend(lower())
|
||||
# Analogous to the previous case, but the top is ragged this time
|
||||
elif end_edges[divergence_idx]:
|
||||
slices.extend(upper())
|
||||
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
|
||||
# If both sides of the range are ragged, we need to handle both sides
|
||||
# separately. If there's contiguous meat in between them, we can index it
|
||||
# in one big chunk
|
||||
else:
|
||||
slices.extend(upper())
|
||||
middle_ground = end[divergence_idx] - start[divergence_idx]
|
||||
if middle_ground > 1:
|
||||
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
|
||||
slices.extend(lower())
|
||||
|
||||
return [tuple(s) for s in slices]
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk_slice(
|
||||
t: torch.Tensor,
|
||||
flat_start: int,
|
||||
flat_end: int,
|
||||
no_batch_dims: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Equivalent to
|
||||
|
||||
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
|
||||
|
||||
but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only
|
||||
reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk
|
||||
size.
|
||||
"""
|
||||
|
||||
batch_dims = t.shape[:no_batch_dims]
|
||||
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
|
||||
# _get_minimal_slice_set is inclusive
|
||||
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
|
||||
|
||||
# Get an ordered list of slices to perform
|
||||
slices = _get_minimal_slice_set(
|
||||
start_idx,
|
||||
end_idx,
|
||||
batch_dims,
|
||||
)
|
||||
|
||||
sliced_tensors = [t[s] for s in slices]
|
||||
|
||||
return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
|
||||
|
||||
|
||||
def chunk_layer(
|
||||
layer: Callable,
|
||||
inputs: Dict[str, Any],
|
||||
chunk_size: int,
|
||||
no_batch_dims: int,
|
||||
low_mem: bool = False,
|
||||
_out: Any = None,
|
||||
_add_into_out: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Implements the "chunking" procedure described in section 1.11.8.
|
||||
|
||||
Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples,
|
||||
and dicts with torch.Tensor leaves.
|
||||
|
||||
Args:
|
||||
layer:
|
||||
The layer to be applied chunk-wise
|
||||
inputs:
|
||||
A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch
|
||||
dimensions.
|
||||
chunk_size:
|
||||
The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined
|
||||
as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product
|
||||
of the batch dimensions).
|
||||
no_batch_dims:
|
||||
How many of the initial dimensions of each input tensor can be considered batch dimensions.
|
||||
low_mem:
|
||||
Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly
|
||||
slower than the default setting.
|
||||
Returns:
|
||||
The reassembled output of the layer on the inputs.
|
||||
"""
|
||||
if not (len(inputs) > 0):
|
||||
raise ValueError("Must provide at least one input")
|
||||
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
def _prep_inputs(t):
|
||||
if not low_mem:
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
t = t.reshape(-1, *t.shape[no_batch_dims:])
|
||||
else:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
return t
|
||||
|
||||
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
|
||||
prepped_outputs = None
|
||||
if _out is not None:
|
||||
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
|
||||
|
||||
def _select_chunk(t):
|
||||
return t[i : i + chunk_size] if t.shape[0] != 1 else t
|
||||
|
||||
i = 0
|
||||
out = prepped_outputs
|
||||
for _ in range(no_chunks):
|
||||
# Chunk the input
|
||||
if not low_mem:
|
||||
select_chunk = _select_chunk
|
||||
else:
|
||||
select_chunk = partial(
|
||||
_chunk_slice,
|
||||
flat_start=i,
|
||||
flat_end=min(flat_batch_dim, i + chunk_size),
|
||||
no_batch_dims=len(orig_batch_dims),
|
||||
)
|
||||
|
||||
chunks = tensor_tree_map(select_chunk, prepped_inputs)
|
||||
|
||||
# Run the layer on the chunk
|
||||
output_chunk = layer(**chunks)
|
||||
|
||||
# Allocate space for the output
|
||||
if out is None:
|
||||
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
|
||||
|
||||
# Put the chunk in its pre-allocated space
|
||||
out_type = type(output_chunk)
|
||||
if out_type is dict:
|
||||
|
||||
def assign(d1, d2):
|
||||
for k, v in d1.items():
|
||||
if type(v) is dict:
|
||||
assign(v, d2[k])
|
||||
else:
|
||||
if _add_into_out:
|
||||
v[i : i + chunk_size] += d2[k]
|
||||
else:
|
||||
v[i : i + chunk_size] = d2[k]
|
||||
|
||||
assign(out, output_chunk)
|
||||
elif out_type is tuple:
|
||||
for x1, x2 in zip(out, output_chunk):
|
||||
if _add_into_out:
|
||||
x1[i : i + chunk_size] += x2
|
||||
else:
|
||||
x1[i : i + chunk_size] = x2
|
||||
elif out_type is torch.Tensor:
|
||||
if _add_into_out:
|
||||
out[i : i + chunk_size] += output_chunk
|
||||
else:
|
||||
out[i : i + chunk_size] = output_chunk
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
i += chunk_size
|
||||
|
||||
out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out)
|
||||
|
||||
return out
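To illustrate the chunking contract described in the docstring of `chunk_layer`, here is a hedged usage sketch; the toy layer and tensor shapes are invented for illustration:

```python
# Hedged sketch: chunk_layer flattens the batch dims, applies the layer in chunks, and reassembles the output.
import torch

from transformers.models.esm.openfold_utils.chunk_utils import chunk_layer


def toy_layer(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Stand-in for an expensive layer; chunk_layer only slices along the batch dimensions.
    return x + y


x = torch.randn(2, 6, 5)  # two batch dims (2, 6) and a feature dim of 5
y = torch.randn(2, 6, 5)

out = chunk_layer(toy_layer, {"x": x, "y": y}, chunk_size=4, no_batch_dims=2)
assert out.shape == (2, 6, 5)
assert torch.allclose(out, x + y)
```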
|
||||
|
||||
|
||||
class ChunkSizeTuner:
|
||||
def __init__(
|
||||
self,
|
||||
# Heuristically, runtimes for most of the modules in the network
|
||||
# plateau earlier than this on all GPUs I've run the model on.
|
||||
max_chunk_size=512,
|
||||
):
|
||||
self.max_chunk_size = max_chunk_size
|
||||
self.cached_chunk_size = None
|
||||
self.cached_arg_data = None
|
||||
|
||||
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
|
||||
logging.info("Tuning chunk size...")
|
||||
|
||||
if min_chunk_size >= self.max_chunk_size:
|
||||
return min_chunk_size
|
||||
|
||||
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
|
||||
candidates = [c for c in candidates if c > min_chunk_size]
|
||||
candidates = [min_chunk_size] + candidates
|
||||
candidates[-1] += 4
|
||||
|
||||
def test_chunk_size(chunk_size):
|
||||
try:
|
||||
with torch.no_grad():
|
||||
fn(*args, chunk_size=chunk_size)
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
min_viable_chunk_size_index = 0
|
||||
i = len(candidates) - 1
|
||||
while i > min_viable_chunk_size_index:
|
||||
viable = test_chunk_size(candidates[i])
|
||||
if not viable:
|
||||
i = (min_viable_chunk_size_index + i) // 2
|
||||
else:
|
||||
min_viable_chunk_size_index = i
|
||||
i = (i + len(candidates) - 1) // 2
|
||||
|
||||
return candidates[min_viable_chunk_size_index]
|
||||
|
||||
def _compare_arg_caches(self, ac1, ac2):
|
||||
consistent = True
|
||||
for a1, a2 in zip(ac1, ac2):
|
||||
assert type(ac1) == type(ac2)
|
||||
if type(ac1) is list or type(ac1) is tuple:
|
||||
consistent &= self._compare_arg_caches(a1, a2)
|
||||
elif type(ac1) is dict:
|
||||
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
|
||||
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
|
||||
consistent &= self._compare_arg_caches(a1_items, a2_items)
|
||||
else:
|
||||
consistent &= a1 == a2
|
||||
|
||||
return consistent
|
||||
|
||||
def tune_chunk_size(
|
||||
self,
|
||||
representative_fn: Callable,
|
||||
args: Tuple[Any],
|
||||
min_chunk_size: int,
|
||||
) -> int:
|
||||
consistent = True
|
||||
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object)
|
||||
if self.cached_arg_data is not None:
|
||||
# If args have changed shape/value, we need to re-tune
|
||||
assert len(self.cached_arg_data) == len(arg_data)
|
||||
consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
|
||||
else:
# No cached args yet, so there is no precomputed value to reuse and we must tune from scratch
consistent = False
|
||||
|
||||
if not consistent:
|
||||
self.cached_chunk_size = self._determine_favorable_chunk_size(
|
||||
representative_fn,
|
||||
args,
|
||||
min_chunk_size,
|
||||
)
|
||||
self.cached_arg_data = arg_data
|
||||
|
||||
return self.cached_chunk_size
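A hedged sketch of how a caller might drive `ChunkSizeTuner`; the representative function here is invented, whereas a real caller would pass one of the model's chunked layers, which raises a `RuntimeError` (out of memory) when the probed chunk size is too large:

```python
# Hedged sketch: the tuner probes increasing chunk sizes and caches the largest one that runs without error.
import torch

from transformers.models.esm.openfold_utils.chunk_utils import ChunkSizeTuner


def representative_fn(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Toy stand-in for a chunked layer; it never fails, so the largest candidate is selected.
    return x * 2


tuner = ChunkSizeTuner(max_chunk_size=128)
x = torch.randn(64, 16)
chunk_size = tuner.tune_chunk_size(representative_fn, (x,), min_chunk_size=4)
print(chunk_size)  # cached and reused as long as the argument shapes stay the same
```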
|
src/transformers/models/esm/openfold_utils/data_transforms.py (new file, 92 lines)
@@ -0,0 +1,92 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import residue_constants as rc
|
||||
from .tensor_utils import tensor_tree_map, tree_map
|
||||
|
||||
|
||||
def make_atom14_masks(protein):
|
||||
"""Construct denser atom positions (14 dimensions instead of 37)."""
|
||||
restype_atom14_to_atom37 = []
|
||||
restype_atom37_to_atom14 = []
|
||||
restype_atom14_mask = []
|
||||
|
||||
for rt in rc.restypes:
|
||||
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
|
||||
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names])
|
||||
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
|
||||
restype_atom37_to_atom14.append(
|
||||
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
|
||||
)
|
||||
|
||||
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
|
||||
|
||||
# Add dummy mapping for restype 'UNK'
|
||||
restype_atom14_to_atom37.append([0] * 14)
|
||||
restype_atom37_to_atom14.append([0] * 37)
|
||||
restype_atom14_mask.append([0.0] * 14)
|
||||
|
||||
restype_atom14_to_atom37 = torch.tensor(
|
||||
restype_atom14_to_atom37,
|
||||
dtype=torch.int32,
|
||||
device=protein["aatype"].device,
|
||||
)
|
||||
restype_atom37_to_atom14 = torch.tensor(
|
||||
restype_atom37_to_atom14,
|
||||
dtype=torch.int32,
|
||||
device=protein["aatype"].device,
|
||||
)
|
||||
restype_atom14_mask = torch.tensor(
|
||||
restype_atom14_mask,
|
||||
dtype=torch.float32,
|
||||
device=protein["aatype"].device,
|
||||
)
|
||||
protein_aatype = protein["aatype"].to(torch.long)
|
||||
|
||||
# create the mapping for (residx, atom14) --> atom37, i.e. an array
|
||||
# with shape (num_res, 14) containing the atom37 indices for this protein
|
||||
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
|
||||
residx_atom14_mask = restype_atom14_mask[protein_aatype]
|
||||
|
||||
protein["atom14_atom_exists"] = residx_atom14_mask
|
||||
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
|
||||
|
||||
# create the gather indices for mapping back
|
||||
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
|
||||
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
|
||||
|
||||
# create the corresponding mask
|
||||
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device)
|
||||
for restype, restype_letter in enumerate(rc.restypes):
|
||||
restype_name = rc.restype_1to3[restype_letter]
|
||||
atom_names = rc.residue_atoms[restype_name]
|
||||
for atom_name in atom_names:
|
||||
atom_type = rc.atom_order[atom_name]
|
||||
restype_atom37_mask[restype, atom_type] = 1
|
||||
|
||||
residx_atom37_mask = restype_atom37_mask[protein_aatype]
|
||||
protein["atom37_atom_exists"] = residx_atom37_mask
|
||||
|
||||
return protein
|
||||
|
||||
|
||||
def make_atom14_masks_np(batch):
|
||||
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
|
||||
out = make_atom14_masks(batch)
|
||||
out = tensor_tree_map(lambda t: np.array(t), out)
|
||||
return out
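A hedged sketch of the mask construction above for a toy three-residue protein; `aatype` uses the 0-20 residue-type indices from `residue_constants`, and the specific indices are illustrative:

```python
# Hedged sketch: building atom14/atom37 masks and index maps for a toy 3-residue protein.
import torch

from transformers.models.esm.openfold_utils.data_transforms import make_atom14_masks

protein = {"aatype": torch.tensor([0, 7, 20])}  # residue-type indices; 20 is the "unknown" residue
protein = make_atom14_masks(protein)

print(protein["atom14_atom_exists"].shape)       # (3, 14): which dense atom14 slots are real per residue
print(protein["residx_atom14_to_atom37"].shape)  # (3, 14): atom37 index for each atom14 slot
print(protein["atom37_atom_exists"].shape)       # (3, 37): which atom37 positions exist per residue
```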
|
src/transformers/models/esm/openfold_utils/feats.py (new file, 234 lines)
@ -0,0 +1,234 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.

import torch
import torch.nn as nn

from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    is_gly = aatype == rc.restype_order["G"]
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx],
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta


def atom14_to_atom37(atom14, batch):
    atom37_data = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=len(atom14.shape[:-2]),
    )

    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

    return atom37_data


def build_template_angle_feat(template_feats):
    template_aatype = template_feats["template_aatype"]
    torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
    torsion_angles_mask = template_feats["template_torsion_angles_mask"]
    template_angle_feat = torch.cat(
        [
            nn.functional.one_hot(template_aatype, 22),
            torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
            alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
            torsion_angles_mask,
        ],
        dim=-1,
    )

    return template_angle_feat


def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8):
    template_mask = batch["template_pseudo_beta_mask"]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    # Compute distogram (this seems to differ slightly from Alg. 5)
    tpb = batch["template_pseudo_beta"]
    dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot = nn.functional.one_hot(
        batch["template_aatype"],
        rc.restype_num + 2,
    )

    n_res = batch["template_aatype"].shape[-1]
    to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1))
    to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1))

    n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
    rigids = Rigid.make_transform_from_reference(
        n_xyz=batch["template_all_atom_positions"][..., n, :],
        ca_xyz=batch["template_all_atom_positions"][..., ca, :],
        c_xyz=batch["template_all_atom_positions"][..., c, :],
        eps=eps,
    )
    points = rigids.get_trans()[..., None, :, :]
    rigid_vec = rigids[..., None].invert_apply(points)

    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))

    t_aa_masks = batch["template_all_atom_mask"]
    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar[..., None]

    if not use_unit_vector:
        unit_vector = unit_vector * 0.0

    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
    to_concat.append(template_mask_2d[..., None])

    act = torch.cat(to_concat, dim=-1)
    act = act * template_mask_2d[..., None]

    return act


def build_extra_msa_feat(batch):
    msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
    msa_feat = [
        msa_1hot,
        batch["extra_has_deletion"].unsqueeze(-1),
        batch["extra_deletion_value"].unsqueeze(-1),
    ]
    return torch.cat(msa_feat, dim=-1)


def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
):
    # [*, N, 8, 4, 4]
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    # One [*, N, 8, 3, 3] rotation matrix and
    # One [*, N, 8, 3] translation matrix
    default_r = r.from_tensor_4x4(default_4x4)

    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_rots = Rigid(Rotation(rot_mats=all_rots), None)

    all_frames = default_r.compose(all_rots)

    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def frames_and_literature_positions_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
    default_frames,
    group_idx,
    atom_mask,
    lit_positions,
):
    # [*, N, 14]
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, 8]
    group_mask = nn.functional.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14]
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

    # [*, N, 14, 1]
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions
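A quick sanity-check sketch for pseudo_beta_fn above (not part of the committed file; it assumes it runs in the same module, so that `rc` and the function are already in scope):

import torch

num_res = 4
aatype = torch.tensor([rc.restype_order["G"], 0, 0, rc.restype_order["G"]])  # two glycines
all_atom_positions = torch.randn(num_res, rc.atom_type_num, 3)
all_atom_masks = torch.ones(num_res, rc.atom_type_num)

# Glycine rows fall back to the CA coordinate, everything else uses CB.
pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks)
print(pseudo_beta.shape, pseudo_beta_mask.shape)  # torch.Size([4, 3]) torch.Size([4])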
src/transformers/models/esm/openfold_utils/loss.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple

import torch


def _calculate_bin_centers(boundaries: torch.Tensor):
    step = boundaries[1] - boundaries[0]
    bin_centers = boundaries + step / 2
    bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
    return bin_centers


def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )


def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: Maximum bin value
        no_bins: Number of bins
    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
            aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
            error for each pair of residues.
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
    (predicted_aligned_error, max_predicted_aligned_error,) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
        "aligned_confidence_probs": aligned_confidence_probs,
        "predicted_aligned_error": predicted_aligned_error,
        "max_predicted_aligned_error": max_predicted_aligned_error,
    }


def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)

    bin_centers = _calculate_bin_centers(boundaries)
    torch.sum(residue_weights)
    n = logits.shape[-2]
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    weighted = per_alignment * residue_weights

    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
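As a rough usage sketch (not part of the commit), the two public helpers can be driven with random logits; this assumes the module is importable under the path added by this diff:

import torch

from transformers.models.esm.openfold_utils.loss import compute_predicted_aligned_error, compute_tm

num_res, no_bins = 16, 64
logits = torch.randn(num_res, num_res, no_bins)  # stand-in for PredictedAlignedErrorHead output

pae = compute_predicted_aligned_error(logits, max_bin=31, no_bins=no_bins)
print(pae["predicted_aligned_error"].shape)  # torch.Size([16, 16]): expected error per residue pair
print(float(compute_tm(logits, max_bin=31, no_bins=no_bins)))  # scalar predicted TM-score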
src/transformers/models/esm/openfold_utils/protein.py (new file, 329 lines)
@@ -0,0 +1,329 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protein data type."""
import dataclasses
import re
import string
from typing import Any, Mapping, Optional, Sequence

import numpy as np

from . import residue_constants


FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.
PICO_TO_ANGSTROM = 0.01


@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation."""

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # Chain indices for multi-chain predictions
    chain_index: Optional[np.ndarray] = None

    # Optional remark about the protein. Included as a comment in output PDB
    # files
    remark: Optional[str] = None

    # Templates used to generate this protein (prediction-only)
    parents: Optional[Sequence[str]] = None

    # Chain corresponding to each parent
    parents_chain_index: Optional[Sequence[int]] = None


def from_proteinnet_string(proteinnet_str: str) -> Protein:
    tag_re = r"(\[[A-Z]+\]\n)"
    tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
    groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])

    atoms = ["N", "CA", "C"]
    aatype = None
    atom_positions = None
    atom_mask = None
    for g in groups:
        if "[PRIMARY]" == g[0]:
            # Work on a list of characters so unknown residues can be replaced in place
            # (strings are immutable).
            seq = list(g[1][0].strip())
            for i in range(len(seq)):
                if seq[i] not in residue_constants.restypes:
                    seq[i] = "X"
            aatype = np.array(
                [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
            )
        elif "[TERTIARY]" == g[0]:
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3])
            atom_positions *= PICO_TO_ANGSTROM
        elif "[MASK]" == g[0]:
            mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
            atom_mask = np.zeros(
                (
                    len(mask),
                    residue_constants.atom_type_num,
                )
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        b_factors=None,
    )


def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
    pdb_headers = []

    remark = prot.remark
    if remark is not None:
        pdb_headers.append(f"REMARK {remark}")

    parents = prot.parents
    parents_chain_index = prot.parents_chain_index
    if parents_chain_index is not None:
        parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]

    if parents is None or len(parents) == 0:
        parents = ["N/A"]

    pdb_headers.append(f"PARENT {' '.join(parents)}")

    return pdb_headers


def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
    """Add pdb headers to an existing PDB string. Useful during multi-chain
    recycling
    """
    out_pdb_lines = []
    lines = pdb_str.split("\n")

    remark = prot.remark
    if remark is not None:
        out_pdb_lines.append(f"REMARK {remark}")

    parents_per_chain = None
    if prot.parents is not None and len(prot.parents) > 0:
        parents_per_chain = []
        if prot.parents_chain_index is not None:
            parent_dict = {}
            for p, i in zip(prot.parents, prot.parents_chain_index):
                parent_dict.setdefault(str(i), [])
                parent_dict[str(i)].append(p)

            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
            for i in range(max_idx + 1):
                chain_parents = parent_dict.get(str(i), ["N/A"])
                parents_per_chain.append(chain_parents)
        else:
            parents_per_chain.append(prot.parents)
    else:
        parents_per_chain = [["N/A"]]

    def make_parent_line(p):
        return f"PARENT {' '.join(p)}"

    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))

    chain_counter = 0
    for i, l in enumerate(lines):
        if "PARENT" not in l and "REMARK" not in l:
            out_pdb_lines.append(l)
        if "TER" in l and "END" not in lines[i + 1]:
            chain_counter += 1
            if not chain_counter >= len(parents_per_chain):
                chain_parents = parents_per_chain[chain_counter]
            else:
                chain_parents = ["N/A"]

            out_pdb_lines.append(make_parent_line(chain_parents))

    return "\n".join(out_pdb_lines)


def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
        prot: The protein to convert to PDB.

    Returns:
        PDB string.
    """
    restypes = residue_constants.restypes + ["X"]

    def res_1to3(r):
        return residue_constants.restype_1to3.get(restypes[r], "UNK")

    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    b_factors = prot.b_factors
    chain_index = prot.chain_index

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    headers = get_pdb_headers(prot)
    if len(headers) > 0:
        pdb_lines.extend(headers)

    n = aatype.shape[0]
    atom_index = 1
    prev_chain_index = 0
    chain_tags = string.ascii_uppercase
    # Add all atom sites.
    for i in range(n):
        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
            charge = ""

            chain_tag = "A"
            if chain_index is not None:
                chain_tag = chain_tags[chain_index[i]]

            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_tag:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

        should_terminate = i == n - 1
        if chain_index is not None:
            if i != n - 1 and chain_index[i + 1] != prev_chain_index:
                should_terminate = True
                prev_chain_index = chain_index[i + 1]

        if should_terminate:
            # Close the chain.
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}"
            )
            pdb_lines.append(chain_termination_line)
            atom_index += 1

            if i != n - 1:
                # "prev" is a misnomer here. This happens at the beginning of
                # each new chain.
                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))

    pdb_lines.append("END")
    pdb_lines.append("")
    return "\n".join(pdb_lines)


def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function
    computes a mask according to heavy atoms that should be present in the given sequence of amino acids.

    Args:
        prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
        An ideal atom mask.
    """
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    remark: Optional[str] = None,
    parents: Optional[Sequence[str]] = None,
    parents_chain_index: Optional[Sequence[int]] = None,
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
        features: Dictionary holding model inputs.
        result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein.
        chain_index: (Optional) Chain indices for multi-chain predictions
        remark: (Optional) Remark about the prediction
        parents: (Optional) List of template names
        parents_chain_index: (Optional) Chain index associated with each parent
    Returns:
        A protein instance.
    """
    if b_factors is None:
        b_factors = np.zeros_like(result["final_atom_mask"])

    return Protein(
        aatype=features["aatype"],
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        residue_index=features["residue_index"] + 1,
        b_factors=b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
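A hedged end-to-end sketch of the data type above (not part of the commit): build a tiny all-alanine Protein by hand and render it with to_pdb. It assumes protein and residue_constants are importable under the openfold_utils path introduced in this diff:

import numpy as np

from transformers.models.esm.openfold_utils import residue_constants
from transformers.models.esm.openfold_utils.protein import Protein, to_pdb

num_res = 2
num_atoms = residue_constants.atom_type_num  # 37 heavy-atom slots per residue
aatype = np.zeros((num_res,), dtype=np.int32)  # restype 0 == alanine

prot = Protein(
    atom_positions=np.zeros((num_res, num_atoms, 3), dtype=np.float32),
    aatype=aatype,
    atom_mask=residue_constants.STANDARD_ATOM_MASK[aatype],  # ideal heavy atoms per residue
    residue_index=np.arange(1, num_res + 1),
    b_factors=np.zeros((num_res, num_atoms), dtype=np.float32),
)
print("\n".join(to_pdb(prot).splitlines()[:3]))  # PARENT header followed by the first ATOM records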
src/transformers/models/esm/openfold_utils/residue_constants.py (new file, 1251 lines)
File diff suppressed because it is too large

src/transformers/models/esm/openfold_utils/rigid_utils.py (new file, 1290 lines)
File diff suppressed because it is too large

src/transformers/models/esm/openfold_utils/tensor_utils.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import List

import torch
import torch.nn as nn


def add(m1, m2, inplace):
    # The first operation in a checkpoint can't be in-place, but it's
    # nice to have in-place addition during inference. Thus...
    if not inplace:
        m1 = m1 + m2
    else:
        m1 += m2

    return m1


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict


def one_hot(x, v_bins):
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Matt note: Editing this to get around the behaviour of using a list as an array index changing
    # in recent Numpy versions
    return data[tuple(ranges)]


# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        print(type(tree))
        raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
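A small usage sketch for the helpers above (not part of the commit; assumes the tensor_utils module path added in this diff):

import torch

from transformers.models.esm.openfold_utils.tensor_utils import batched_gather, permute_final_dims, tensor_tree_map

x = torch.arange(24.0).reshape(2, 3, 4)  # [batch, res, feat]
print(permute_final_dims(x, [1, 0]).shape)  # torch.Size([2, 4, 3]): swaps the last two dims

idx = torch.tensor([[2, 0, 1], [1, 1, 0]])  # per-batch residue reordering
print(batched_gather(x, idx, dim=-2, no_batch_dims=1).shape)  # torch.Size([2, 3, 4])

nested = {"single": torch.ones(2), "pair": [torch.zeros(3)]}
print(tensor_tree_map(lambda t: t * 2, nested)["single"])  # tensor([2., 2.])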
@@ -1985,6 +1985,13 @@ class ErniePreTrainedModel(metaclass=DummyObject):
ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None


class EsmFoldPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class EsmForMaskedLM(metaclass=DummyObject):
    _backends = ["torch"]

@@ -1992,6 +1999,13 @@ class EsmForMaskedLM(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class EsmForProteinFolding(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class EsmForSequenceClassification(metaclass=DummyObject):
    _backends = ["torch"]

@@ -20,7 +20,6 @@ import unittest
from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device

-from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask

@@ -49,7 +48,7 @@ class EsmModelTester:
        self.use_input_mask = True
        self.use_token_type_ids = False
        self.use_labels = True
-        self.vocab_size = 99
+        self.vocab_size = 33
        self.hidden_size = 32
        self.num_hidden_layers = 5
        self.num_attention_heads = 4
@@ -145,7 +144,7 @@ class EsmModelTester:


@require_torch
-class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class EsmModelTest(ModelTesterMixin, unittest.TestCase):

    test_mismatched_shapes = False

@@ -253,28 +252,32 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class EsmModelIntegrationTest(TestCasePlus):
    @slow
    def test_inference_masked_lm(self):
-        model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
-        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
-        output = model(input_ids)[0]
+        with torch.no_grad():
+            model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+            model.eval()
+            input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
+            output = model(input_ids)[0]

-        vocab_size = 33
+            vocab_size = 33

-        expected_shape = torch.Size((1, 6, vocab_size))
-        self.assertEqual(output.shape, expected_shape)
+            expected_shape = torch.Size((1, 6, vocab_size))
+            self.assertEqual(output.shape, expected_shape)

-        expected_slice = torch.tensor(
-            [[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
-        )
-        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+            expected_slice = torch.tensor(
+                [[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
+            )
+            self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head(self):
-        model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+        with torch.no_grad():
+            model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+            model.eval()

-        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
-        output = model(input_ids)[0]
-        # compare the actual values for a slice.
-        expected_slice = torch.tensor(
-            [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
-        )
-        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+            input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+            output = model(input_ids)[0]
+            # compare the actual values for a slice.
+            expected_slice = torch.tensor(
+                [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
+            )
+            self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
tests/models/esm/test_modeling_esmfold.py (new file, 234 lines)
@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ESM model. """


import unittest

from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask


if is_torch_available():
    import torch

    from transformers.models.esm.modeling_esmfold import EsmForProteinFolding


class EsmFoldModelTester:
    def __init__(
        self,
        parent,
    ):
        self.parent = parent
        self.batch_size = 13
        self.seq_length = 7
        self.is_training = False
        self.use_input_mask = True
        self.use_token_type_ids = False
        self.use_labels = False
        self.vocab_size = 19
        self.hidden_size = 32
        self.num_hidden_layers = 5
        self.num_attention_heads = 4
        self.intermediate_size = 37
        self.hidden_act = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 16
        self.type_sequence_label_size = 2
        self.initializer_range = 0.02
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        config = EsmConfig(
            vocab_size=33,
            hidden_size=self.hidden_size,
            pad_token_id=1,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            is_folding_model=True,
            esmfold_config={"trunk": {"num_blocks": 2}, "fp16_esm": False},
        )
        return config

    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
        model = EsmForProteinFolding(config=config).float()
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.positions.shape, (8, self.batch_size, self.seq_length, 14, 3))
        self.parent.assertEqual(result.angles.shape, (8, self.batch_size, self.seq_length, 7, 2))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):

    test_mismatched_shapes = False

    all_model_classes = (EsmForProteinFolding,) if is_torch_available() else ()
    all_generative_model_classes = ()
    test_sequence_classification_problem_types = False

    def setUp(self):
        self.model_tester = EsmFoldModelTester(self)
        self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip("Does not support attention outputs")
    def test_attention_outputs(self):
        pass

    @unittest.skip
    def test_correct_missing_keys(self):
        pass

    @unittest.skip("Esm does not support embedding resizing")
    def test_resize_embeddings_untied(self):
        pass

    @unittest.skip("Esm does not support embedding resizing")
    def test_resize_tokens_embeddings(self):
        pass

    @unittest.skip("ESMFold does not support passing input embeds!")
    def test_inputs_embeds(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning_integration(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning_save_load_from_config_init(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning_save_load_from_pretrained(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_headmasking(self):
        pass

    @unittest.skip("ESMFold does not output hidden states in the normal way.")
    def test_hidden_states_output(self):
        pass

    @unittest.skip("ESMfold does not output hidden states in the normal way.")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip("ESMFold only has one output format.")
    def test_model_outputs_equivalence(self):
        pass

    @unittest.skip("This test doesn't work for ESMFold and doesn't test core functionality")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("ESMFold does not support input chunking.")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip("ESMFold doesn't respect you and it certainly doesn't respect your initialization arguments.")
    def test_initialization(self):
        pass

    @unittest.skip("ESMFold doesn't support torchscript compilation.")
    def test_torchscript_output_attentions(self):
        pass

    @unittest.skip("ESMFold doesn't support torchscript compilation.")
    def test_torchscript_output_hidden_state(self):
        pass

    @unittest.skip("ESMFold doesn't support torchscript compilation.")
    def test_torchscript_simple(self):
        pass

    @unittest.skip("ESMFold doesn't support data parallel.")
    def test_multi_gpu_data_parallel_forward(self):
        pass


@require_torch
class EsmModelIntegrationTest(TestCasePlus):
    @slow
    def test_inference_protein_folding(self):
        model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
        model.eval()
        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
        position_outputs = model(input_ids)["positions"]
        expected_slice = torch.tensor([2.5828, 0.7993, -10.9334], dtype=torch.float32)
        self.assertTrue(torch.allclose(position_outputs[0, 0, 0, 0], expected_slice, atol=1e-4))
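Beyond the slow test above, a hedged inference sketch (not part of the test file) would look roughly like this; it assumes the staging checkpoint used in the test also ships a matching tokenizer on the Hub:

import torch

from transformers import AutoTokenizer, EsmForProteinFolding

tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/esmfold_v1")  # assumed to be present alongside the weights
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float().eval()

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs.positions.shape)  # [num_recycles, batch, seq_len, 14, 3] atom14 coordinates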
@@ -254,7 +254,7 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):

@require_tf
class TFEsmModelIntegrationTest(unittest.TestCase):
-    @slow
+    @unittest.skip("Temporarily disabled as we update ESM model checkpoints")
    def test_inference_masked_lm(self):
        model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

@@ -268,7 +268,7 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
        )
        self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))

-    @slow
+    @unittest.skip("Temporarily disabled as we update ESM model checkpoints")
    def test_inference_no_head(self):
        model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

@@ -268,6 +268,7 @@ def get_transformers_submodules():
IGNORE_SUBMODULES = [
    "convert_pytorch_checkpoint_to_tf2",
    "modeling_flax_pytorch_utils",
    "models.esm.openfold_utils",
]


@@ -140,6 +140,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "EsmForProteinFolding",
    "TimeSeriesTransformerForPrediction",
    "PegasusXEncoder",
    "PegasusXDecoder",