ESM openfold_utils type hints (#20544)

* add type annotations for esm chunk_utils

use the isinstance builtin instead of 'type(x) is y'; add assertions to aid type inference; use bools instead of ints in _get_minimal_slice_set for improved type clarity; refactor to avoid re-assigning to the same variable with a different type
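
For illustration, the two main patterns look roughly like this (a minimal sketch with made-up names, not code from the diff below):

    from typing import List, Tuple, Union

    import torch


    def fetch_shapes(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
        # isinstance() also matches subclasses and lets checkers narrow the type,
        # unlike 'type(tree) is dict'
        if isinstance(tree, dict):
            return [s for v in tree.values() for s in fetch_shapes(v)]
        elif isinstance(tree, (list, tuple)):
            return [s for t in tree for s in fetch_shapes(t)]
        elif isinstance(tree, torch.Tensor):
            return [tuple(tree.shape)]
        raise ValueError("Not supported")


    def as_tensor(values: List[float]) -> torch.Tensor:
        # bind the converted value to a new name instead of re-assigning 'values',
        # so each variable keeps a single static type
        values_t = torch.tensor(values)
        return values_t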

* add type annotations for esm data_transforms

refactor to avoid re-assigning to the same variable with a different type

* add type annotations for esm feats utils

refactor to avoid re-assigning to the same variable with a different type

* add type annotations for esm loss utils

* add/fix type annotations for esm rigid_utils

refactor to avoid re-assigning to the same variable with a different type; fix Callable, Tuple type hints; match conditional structure to other methods; fix return type on Rotation.cat and Rotation.unsqueeze
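
For reference, the corrected hint forms look like this (a sketch with made-up names): Callable takes a list of argument types, and a variable-length tuple of ints is written with an ellipsis.

    from typing import Callable, Tuple

    import torch

    # Callable[[torch.Tensor], torch.Tensor], not Callable[torch.Tensor, torch.Tensor]
    TensorFn = Callable[[torch.Tensor], torch.Tensor]


    def zeros_of_shape(batch_dims: Tuple[int, ...]) -> torch.Tensor:
        # Tuple[int, ...] accepts any number of ints; Tuple[int] means exactly one
        return torch.zeros(*batch_dims)


    def double(t: torch.Tensor) -> torch.Tensor:
        return t * 2


    fn: TensorFn = double
    print(fn(zeros_of_shape((2, 3))).shape)  # torch.Size([2, 3])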

* add type annotations for esm tensor_utils

add overloads for tree_map; use the isinstance builtin instead of 'type(x) is y'; export dict_multimap, flatten_final_dims, permute_final_dims in openfold_utils
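
The overload pattern referred to here can be sketched as follows (simplified; the real implementation also handles tuples). It lets a type checker infer dict in, dict out and list in, list out:

    from typing import Any, Callable, Type, TypeVar, overload

    import torch

    T = TypeVar("T")


    @overload
    def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict: ...
    @overload
    def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list: ...
    @overload
    def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any: ...


    def tree_map(fn, tree, leaf_type):
        # simplified body
        if isinstance(tree, dict):
            return {k: tree_map(fn, v, leaf_type) for k, v in tree.items()}
        if isinstance(tree, list):
            return [tree_map(fn, x, leaf_type) for x in tree]
        if isinstance(tree, leaf_type):
            return fn(tree)
        raise ValueError("Not supported")


    # a checker now infers a dict here, because the dict overload matches
    shapes = tree_map(lambda t: t.shape, {"x": torch.zeros(2, 3)}, torch.Tensor)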

* add type annotations for esm protein utils

add FIXME for attempted string mutation; add missing None check in get_pdb_headers; fix potentially unbound variable 'chain_tag' in to_pdb; modify get_pdb_headers return type
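
The "potentially unbound variable" fix follows the usual pattern of binding a default before the loop and checking it afterwards; a sketch with made-up names:

    import string
    from typing import Optional, Sequence


    def last_chain_tag(chain_indices: Sequence[int]) -> str:
        chain_tag: Optional[str] = None  # bound up front, so the name is always defined
        for idx in chain_indices:
            chain_tag = string.ascii_uppercase[idx]
        if chain_tag is None:
            # explicit None check instead of silently using an unbound name
            raise ValueError("chain_indices was empty")
        return chain_tag


    print(last_chain_tag([0, 2]))  # "C"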

* add type annotations for esm residue constants

add type hints to collection constants; remove the magic trailing comma to reduce the line count; change list -> tuple for rigid_group_atom_positions for improved hinting
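
Typed tuple constants give the checker exact element types and signal that the table is read-only; a toy example with illustrative values (not the real constant):

    from typing import Dict, Tuple

    # each entry is (atom name, rigid group index, (x, y, z) position)
    AtomEntry = Tuple[str, int, Tuple[float, float, float]]

    example_atom_positions: Dict[str, Tuple[AtomEntry, ...]] = {
        "ALA": (
            ("N", 0, (-0.5, 1.4, 0.0)),
            ("CA", 0, (0.0, 0.0, 0.0)),
        ),
    }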

* code style fixup

Co-authored-by: Matt <rocketknight1@gmail.com>
Matthew Hoffman 2022-12-05 10:23:15 -06:00 committed by GitHub
parent 8ea6694d92
commit afe2a466bb
9 changed files with 491 additions and 760 deletions

View File

@ -6,3 +6,4 @@ from .loss import compute_predicted_aligned_error, compute_tm
from .protein import Protein as OFProtein
from .protein import to_pdb
from .rigid_utils import Rigid, Rotation
from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims

View File

@ -14,23 +14,22 @@
import logging
import math
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
from .tensor_utils import tensor_tree_map, tree_map
def _fetch_dims(tree):
def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
shapes = []
tree_type = type(tree)
if tree_type is dict:
if isinstance(tree, dict):
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
elif isinstance(tree, (list, tuple)):
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
elif isinstance(tree, torch.Tensor):
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
@ -39,10 +38,7 @@ def _fetch_dims(tree):
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
@ -55,10 +51,10 @@ def _flat_idx_to_idx(
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
) -> List[Tuple[slice, ...]]:
"""
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
@ -69,11 +65,11 @@ def _get_minimal_slice_set(
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
def reduce_edge_list(l: List[bool]) -> None:
tally = True
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
l[reversed_idx] &= tally
tally = l[reversed_idx]
if start_edges is None:
@ -90,48 +86,54 @@ def _get_minimal_slice_set(
elif len(start) == 1:
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
slices: List[Tuple[slice, ...]] = []
path_list: List[slice] = []
# Dimensions common to start and end can be selected directly
for s, e in zip(start, end):
if s == e:
path.append(slice(s, s + 1))
path_list.append(slice(s, s + 1))
else:
break
path = tuple(path)
path: Tuple[slice, ...] = tuple(path_list)
divergence_idx = len(path)
# start == end, and we're done
if divergence_idx == len(dims):
return [tuple(path)]
return [path]
def upper() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None
def upper():
sdi = start[divergence_idx]
return [
return tuple(
path + (slice(sdi, sdi + 1),) + s
for s in _get_minimal_slice_set(
start[divergence_idx + 1 :],
[d - 1 for d in dims[divergence_idx + 1 :]],
dims[divergence_idx + 1 :],
start_edges=start_edges[divergence_idx + 1 :],
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
)
]
)
def lower() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None
def lower():
edi = end[divergence_idx]
return [
return tuple(
path + (slice(edi, edi + 1),) + s
for s in _get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1 :]],
end[divergence_idx + 1 :],
dims[divergence_idx + 1 :],
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
end_edges=end_edges[divergence_idx + 1 :],
)
]
)
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
@ -156,16 +158,11 @@ def _get_minimal_slice_set(
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
slices.extend(lower())
return [tuple(s) for s in slices]
return slices
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
"""
Equivalent to
@ -232,7 +229,7 @@ def chunk_layer(
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
if not low_mem:
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
@ -241,7 +238,7 @@ def chunk_layer(
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if _out is not None:
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
@ -252,7 +249,7 @@ def chunk_layer(
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
def _select_chunk(t):
def _select_chunk(t: torch.Tensor) -> torch.Tensor:
return t[i : i + chunk_size] if t.shape[0] != 1 else t
i = 0
@ -269,7 +266,7 @@ def chunk_layer(
no_batch_dims=len(orig_batch_dims),
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
@ -279,12 +276,11 @@ def chunk_layer(
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
if isinstance(output_chunk, dict):
def assign(d1, d2):
def assign(d1: dict, d2: dict) -> None:
for k, v in d1.items():
if type(v) is dict:
if isinstance(v, dict):
assign(v, d2[k])
else:
if _add_into_out:
@ -293,13 +289,13 @@ def chunk_layer(
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
elif isinstance(output_chunk, tuple):
for x1, x2 in zip(out, output_chunk):
if _add_into_out:
x1[i : i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
elif isinstance(output_chunk, torch.Tensor):
if _add_into_out:
out[i : i + chunk_size] += output_chunk
else:
@ -319,24 +315,24 @@ class ChunkSizeTuner:
self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=512,
max_chunk_size: int = 512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
self.cached_chunk_size: Optional[int] = None
self.cached_arg_data: Optional[tuple] = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
logging.info("Tuning chunk size...")
if min_chunk_size >= self.max_chunk_size:
return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4
def test_chunk_size(chunk_size):
def test_chunk_size(chunk_size: int) -> bool:
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
@ -356,13 +352,13 @@ class ChunkSizeTuner:
return candidates[min_viable_chunk_size_index]
def _compare_arg_caches(self, ac1, ac2):
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
consistent = True
for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2)
if type(ac1) is list or type(ac1) is tuple:
if isinstance(ac1, (list, tuple)):
consistent &= self._compare_arg_caches(a1, a2)
elif type(ac1) is dict:
elif isinstance(ac1, dict):
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
consistent &= self._compare_arg_caches(a1_items, a2_items)
@ -374,11 +370,11 @@ class ChunkSizeTuner:
def tune_chunk_size(
self,
representative_fn: Callable,
args: Tuple[Any],
args: tuple,
min_chunk_size: int,
) -> int:
consistent = True
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object)
arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data)
@ -395,4 +391,6 @@ class ChunkSizeTuner:
)
self.cached_arg_data = arg_data
assert self.cached_chunk_size is not None
return self.cached_chunk_size

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import numpy as np
import torch
@ -20,39 +22,39 @@ from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map
def make_atom14_masks(protein):
def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
restype_atom14_to_atom37_list = []
restype_atom37_to_atom14_list = []
restype_atom14_mask_list = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names])
restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
restype_atom37_to_atom14_list.append(
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
)
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37_list.append([0] * 14)
restype_atom37_to_atom14_list.append([0] * 37)
restype_atom14_mask_list.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
restype_atom14_to_atom37_list,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
restype_atom37_to_atom14_list,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
restype_atom14_mask_list,
dtype=torch.float32,
device=protein["aatype"].device,
)
@ -85,8 +87,7 @@ def make_atom14_masks(protein):
return protein
def make_atom14_masks_np(batch):
def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
return out

View File

@ -13,14 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Tuple, overload
import torch
import torch.nn as nn
import torch.types
from torch import nn
from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather
@overload
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor:
...
@overload
def pseudo_beta_fn(
aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
...
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
@ -42,7 +57,7 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return pseudo_beta
def atom14_to_atom37(atom14, batch):
def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
@ -55,7 +70,7 @@ def atom14_to_atom37(atom14, batch):
return atom37_data
def build_template_angle_feat(template_feats):
def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
@ -73,7 +88,15 @@ def build_template_angle_feat(template_feats):
return template_angle_feat
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8):
def build_template_pair_feat(
batch: Dict[str, torch.Tensor],
min_bin: torch.types.Number,
max_bin: torch.types.Number,
no_bins: int,
use_unit_vector: bool = False,
eps: float = 1e-20,
inf: float = 1e8,
) -> torch.Tensor:
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@ -86,7 +109,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
aatype_one_hot: torch.LongTensor = nn.functional.one_hot(
batch["template_aatype"],
rc.restype_num + 2,
)
@ -126,8 +149,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F
return act
def build_extra_msa_feat(batch):
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
msa_1hot: torch.LongTensor = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
batch["extra_has_deletion"].unsqueeze(-1),
@ -141,7 +164,7 @@ def torsion_angles_to_frames(
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
) -> Rigid:
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
@ -172,9 +195,7 @@ def torsion_angles_to_frames(
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
@ -203,22 +224,22 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
default_frames: torch.Tensor,
group_idx: torch.Tensor,
atom_mask: torch.Tensor,
lit_positions: torch.Tensor,
) -> torch.Tensor:
# [*, N, 14]
group_mask = group_idx[aatype, ...]
# [*, N, 14, 8]
group_mask = nn.functional.one_hot(
group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
# [*, N, 14, 8]
t_atoms_to_global = r[..., None, :] * group_mask
t_atoms_to_global = r[..., None, :] * group_mask_one_hot
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

View File

@ -18,7 +18,7 @@ from typing import Dict, Optional, Tuple
import torch
def _calculate_bin_centers(boundaries: torch.Tensor):
def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor:
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)

View File

@ -17,7 +17,7 @@
import dataclasses
import re
import string
from typing import Any, Mapping, Optional, Sequence
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
import numpy as np
@ -69,10 +69,10 @@ class Protein:
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r"(\[[A-Z]+\]\n)"
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
atoms = ["N", "CA", "C"]
atoms: List[str] = ["N", "CA", "C"]
aatype = None
atom_positions = None
atom_mask = None
@ -81,12 +81,12 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
seq = g[1][0].strip()
for i in range(len(seq)):
if seq[i] not in residue_constants.restypes:
seq[i] = "X"
seq[i] = "X" # FIXME: strings are immutable
aatype = np.array(
[residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
)
elif "[TERTIARY]" == g[0]:
tertiary = []
tertiary: List[List[float]] = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
@ -106,6 +106,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
assert aatype is not None
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
@ -115,8 +117,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
pdb_headers = []
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]:
pdb_headers: List[str] = []
remark = prot.remark
if remark is not None:
@ -124,7 +126,7 @@ def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
parents = prot.parents
parents_chain_index = prot.parents_chain_index
if parents_chain_index is not None:
if parents is not None and parents_chain_index is not None:
parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
if parents is None or len(parents) == 0:
@ -139,18 +141,18 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
"""Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines = []
out_pdb_lines: List[str] = []
lines = pdb_str.split("\n")
remark = prot.remark
if remark is not None:
out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain = None
parents_per_chain: List[List[str]]
if prot.parents is not None and len(prot.parents) > 0:
parents_per_chain = []
if prot.parents_chain_index is not None:
parent_dict = {}
parent_dict: Dict[str, List[str]] = {}
for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p)
@ -160,11 +162,11 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents)
else:
parents_per_chain.append(prot.parents)
parents_per_chain.append(list(prot.parents))
else:
parents_per_chain = [["N/A"]]
def make_parent_line(p):
def make_parent_line(p: Sequence[str]) -> str:
return f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
@ -196,12 +198,12 @@ def to_pdb(prot: Protein) -> str:
"""
restypes = residue_constants.restypes + ["X"]
def res_1to3(r):
def res_1to3(r: int) -> str:
return residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
pdb_lines: List[str] = []
atom_mask = prot.atom_mask
aatype = prot.aatype
@ -221,6 +223,7 @@ def to_pdb(prot: Protein) -> str:
atom_index = 1
prev_chain_index = 0
chain_tags = string.ascii_uppercase
chain_tag = None
# Add all atom sites.
for i in range(n):
res_name_3 = res_1to3(aatype[i])
@ -313,15 +316,12 @@ def from_prediction(
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors,
b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]),
chain_index=chain_index,
remark=remark,
parents=parents,

View File

@ -16,7 +16,7 @@
from __future__ import annotations
from functools import lru_cache
from typing import Any, Callable, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
@ -33,7 +33,7 @@ def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
The product ab
"""
def row_mul(i):
def row_mul(i: int) -> torch.Tensor:
return torch.stack(
[
a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
@ -76,7 +76,7 @@ def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
@lru_cache(maxsize=None)
def identity_rot_mats(
batch_dims: Tuple[int],
batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
@ -91,7 +91,7 @@ def identity_rot_mats(
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int],
batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
@ -102,7 +102,7 @@ def identity_trans(
@lru_cache(maxsize=None)
def identity_quats(
batch_dims: Tuple[int],
batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
@ -115,15 +115,14 @@ def identity_quats(
return quat
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
_quat_elements: List[str] = ["a", "b", "c", "d"]
_qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
for key, value in pairs:
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
@ -165,14 +164,11 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
return torch.sum(quat, dim=(-3, -4))
def rot_to_quat(
rot: torch.Tensor,
):
def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
if rot.shape[-2:] != (3, 3):
raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]
k = [
[
@ -201,9 +197,7 @@ def rot_to_quat(
],
]
k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
_, vectors = torch.linalg.eigh(k)
_, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
return vectors[..., -1]
@ -218,7 +212,7 @@ _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0,
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_CACHED_QUATS = {
_CACHED_QUATS: Dict[str, np.ndarray] = {
"_QTR_MAT": _QTR_MAT,
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
@ -226,29 +220,29 @@ _CACHED_QUATS = {
@lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device):
def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
def quat_multiply(quat1, quat2):
def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
"""Multiply a quaternion by another quaternion."""
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
def quat_multiply_by_vec(quat, vec):
def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
"""Multiply a quaternion by a pure-vector quaternion."""
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
def invert_rot_mat(rot_mat: torch.Tensor):
def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
return rot_mat.transpose(-1, -2)
def invert_quat(quat: torch.Tensor):
def invert_quat(quat: torch.Tensor) -> torch.Tensor:
quat_prime = quat.clone()
quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
@ -361,10 +355,7 @@ class Rotation:
else:
raise ValueError("Both rotations are None")
def __mul__(
self,
right: torch.Tensor,
) -> Rotation:
def __mul__(self, right: torch.Tensor) -> Rotation:
"""
Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
@ -386,10 +377,7 @@ class Rotation:
else:
raise ValueError("Both rotations are None")
def __rmul__(
self,
left: torch.Tensor,
) -> Rotation:
def __rmul__(self, left: torch.Tensor) -> Rotation:
"""
Reverse pointwise multiplication of the rotation with a tensor.
@ -413,13 +401,12 @@ class Rotation:
Returns:
The virtual shape of the rotation object
"""
s = None
if self._quats is not None:
s = self._quats.shape[:-1]
if self._rot_mats is not None:
return self._rot_mats.shape[:-2]
elif self._quats is not None:
return self._quats.shape[:-1]
else:
s = self._rot_mats.shape[:-2]
return s
raise ValueError("Both rotations are None")
@property
def dtype(self) -> torch.dtype:
@ -473,14 +460,12 @@ class Rotation:
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats = self._rot_mats
if rot_mats is None:
if self._quats is None:
raise ValueError("Both rotations are None")
else:
rot_mats = quat_to_rot(self._quats)
return rot_mats
if self._rot_mats is not None:
return self._rot_mats
elif self._quats is not None:
return quat_to_rot(self._quats)
else:
raise ValueError("Both rotations are None")
def get_quats(self) -> torch.Tensor:
"""
@ -491,14 +476,12 @@ class Rotation:
Returns:
The rotation as a quaternion tensor.
"""
quats = self._quats
if quats is None:
if self._rot_mats is None:
raise ValueError("Both rotations are None")
else:
quats = rot_to_quat(self._rot_mats)
return quats
if self._rot_mats is not None:
return rot_to_quat(self._rot_mats)
elif self._quats is not None:
return self._quats
else:
raise ValueError("Both rotations are None")
def get_cur_rot(self) -> torch.Tensor:
"""
@ -618,10 +601,7 @@ class Rotation:
# "Tensor" stuff
def unsqueeze(
self,
dim: int,
) -> Rigid:
def unsqueeze(self, dim: int) -> Rotation:
"""
Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
@ -643,10 +623,7 @@ class Rotation:
raise ValueError("Both rotations are None")
@staticmethod
def cat(
rs: Sequence[Rotation],
dim: int,
) -> Rigid:
def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
"""
Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
@ -661,12 +638,14 @@ class Rotation:
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats = [r.get_rot_mats() for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
rot_mats = torch.cat(
[r.get_rot_mats() for r in rs],
dim=dim if dim >= 0 else dim - 2,
)
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
be used e.g. to sum out a one-hot batch dimension.
@ -754,11 +733,7 @@ class Rigid:
dimensions of its component parts.
"""
def __init__(
self,
rots: Optional[Rotation],
trans: Optional[torch.Tensor],
):
def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
"""
Args:
rots: A [*, 3, 3] rotation tensor
@ -795,6 +770,9 @@ class Rigid:
requires_grad,
)
assert rots is not None
assert trans is not None
if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
raise ValueError("Rots and trans incompatible")
@ -806,7 +784,7 @@ class Rigid:
@staticmethod
def identity(
shape: Tuple[int],
shape: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
@ -832,10 +810,7 @@ class Rigid:
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(
self,
index: Any,
) -> Rigid:
def __getitem__(self, index: Any) -> Rigid:
"""
Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
both the rotation and the translation.
@ -860,10 +835,7 @@ class Rigid:
self._trans[index + (slice(None),)],
)
def __mul__(
self,
right: torch.Tensor,
) -> Rigid:
def __mul__(self, right: torch.Tensor) -> Rigid:
"""
Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.
@ -881,10 +853,7 @@ class Rigid:
return Rigid(new_rots, new_trans)
def __rmul__(
self,
left: torch.Tensor,
) -> Rigid:
def __rmul__(self, left: torch.Tensor) -> Rigid:
"""
Reverse pointwise multiplication of the transformation with a tensor.
@ -904,8 +873,7 @@ class Rigid:
Returns:
The shape of the transformation
"""
s = self._trans.shape[:-1]
return s
return self._trans.shape[:-1]
@property
def device(self) -> torch.device:
@ -935,10 +903,7 @@ class Rigid:
"""
return self._trans
def compose_q_update_vec(
self,
q_update_vec: torch.Tensor,
) -> Rigid:
def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
"""
Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.
@ -956,10 +921,7 @@ class Rigid:
return Rigid(new_rots, new_translation)
def compose(
self,
r: Rigid,
) -> Rigid:
def compose(self, r: Rigid) -> Rigid:
"""
Composes the current rigid object with another.
@ -973,10 +935,7 @@ class Rigid:
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(
self,
pts: torch.Tensor,
) -> torch.Tensor:
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Applies the transformation to a coordinate tensor.
@ -1012,7 +971,7 @@ class Rigid:
return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
"""
Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
translation/rotation dimensions respectively.
@ -1074,10 +1033,7 @@ class Rigid:
return tensor
@staticmethod
def from_tensor_7(
t: torch.Tensor,
normalize_quats: bool = False,
) -> Rigid:
def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
if t.shape[-1] != 7:
raise ValueError("Incorrectly shaped input tensor")
@ -1102,18 +1058,18 @@ class Rigid:
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
origin_unbound = torch.unbind(origin, dim=-1)
p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]
denom = torch.sqrt(sum((c * c for c in e0)) + eps)
denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
@ -1126,12 +1082,9 @@ class Rigid:
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
def unsqueeze(
self,
dim: int,
) -> Rigid:
def unsqueeze(self, dim: int) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
@ -1148,10 +1101,7 @@ class Rigid:
return Rigid(rots, trans)
@staticmethod
def cat(
ts: Sequence[Rigid],
dim: int,
) -> Rigid:
def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
"""
Concatenates transformations along a new dimension.
@ -1168,7 +1118,7 @@ class Rigid:
return Rigid(rots, trans)
def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
"""
Applies a Rotation -> Rotation function to the stored rotation object.
@ -1179,7 +1129,7 @@ class Rigid:
"""
return Rigid(fn(self._rots), self._trans)
def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
@ -1213,7 +1163,9 @@ class Rigid:
return self.apply_rot_fn(lambda r: r.detach())
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
def make_transform_from_reference(
n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
) -> Rigid:
"""
Returns a transformation object from reference coordinates.

View File

@ -14,13 +14,14 @@
# limitations under the License.
from functools import partial
from typing import List
from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload
import torch
import torch.nn as nn
import torch.types
def add(m1, m2, inplace):
def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if not inplace:
@ -31,33 +32,35 @@ def add(m1, m2, inplace):
return m1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int):
def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
return t.reshape(t.shape[:-no_dims] + (-1,))
def masked_mean(mask, value, dim, eps=1e-4):
def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
def pts_to_distogram(
pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
) -> torch.Tensor:
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
return torch.bucketize(dists, boundaries)
def dict_multimap(fn, dicts):
def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if type(v) is dict:
if isinstance(v, dict):
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
@ -65,21 +68,21 @@ def dict_multimap(fn, dicts):
return new_dict
def one_hot(x, v_bins):
def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
ranges = []
def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
ranges: List[Union[slice, torch.Tensor]] = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
# Matt note: Editing this to get around the behaviour of using a list as an array index changing
@ -87,11 +90,16 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
return data[tuple(ranges)]
T = TypeVar("T")
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
new_dict = {}
def dict_map(
fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
) -> Dict[Any, Union[dict, list, tuple, Any]]:
new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
for k, v in dic.items():
if type(v) is dict:
if isinstance(v, dict):
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
@ -99,13 +107,33 @@ def dict_map(fn, dic, leaf_type):
return new_dict
@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
...
def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
return tuple(tree_map(fn, x, leaf_type) for x in tree)
elif isinstance(tree, leaf_type):
return fn(tree)
else: